diff --git a/BUILD.bazel b/BUILD.bazel index eb7cf72..01c4052 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -85,6 +85,9 @@ test_suite( cc_library( name = "ops", + srcs = [ + "ops/matmul.cc", + ], hdrs = [ "ops/matmul.h", "ops/ops.h", @@ -103,12 +106,14 @@ cc_library( ":threading", "//compression:compress", "@highway//:algo", + "@highway//:bit_set", "@highway//:hwy", "@highway//:math", "@highway//:matvec", "@highway//:nanobenchmark", # timer "@highway//:profiler", "@highway//:thread_pool", + "@highway//:topology", ], ) @@ -126,6 +131,7 @@ cc_test( ":test_util", ":threading", "@googletest//:gtest_main", # buildcleaner: keep + "//:app", "//compression:compress", "//compression:test_util", "@highway//:hwy", @@ -151,6 +157,7 @@ cc_test( ":ops", ":test_util", "@googletest//:gtest_main", # buildcleaner: keep + "//:app", "//compression:compress", "@highway//:hwy", "@highway//:hwy_test_util", @@ -176,26 +183,6 @@ cc_test( ], ) -cc_test( - name = "matmul_unit_test", - size = "small", - timeout = "long", - srcs = ["ops/matmul_unit_test.cc"], - local_defines = ["HWY_IS_TEST"], - # for test_suite. - tags = ["ops_tests"], - deps = [ - ":allocator", - ":basics", - ":ops", - ":test_util", - "@googletest//:gtest_main", # buildcleaner: keep - "//compression:compress", - "@highway//:hwy", - "@highway//:hwy_test_util", - ], -) - cc_test( name = "matmul_test", size = "small", @@ -652,6 +639,7 @@ cc_test( ":sampler", ":weights", "@googletest//:gtest_main", + "//:threading", "//compression:compress", "@highway//:hwy", "@highway//:hwy_test_util", diff --git a/CMakeLists.txt b/CMakeLists.txt index 70ce270..3b3f8b8 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -92,6 +92,8 @@ set(SOURCES gemma/weights.h ops/dot-inl.h ops/matmul-inl.h + ops/matmul.cc + ops/matmul.h ops/matvec-inl.h ops/ops-inl.h ops/ops.h @@ -168,7 +170,6 @@ set(GEMMA_TEST_FILES ops/dot_test.cc ops/gemma_matvec_test.cc ops/matmul_test.cc - ops/matmul_unit_test.cc ops/ops_test.cc paligemma/image_test.cc paligemma/paligemma_test.cc diff --git a/backprop/backward_test.cc b/backprop/backward_test.cc index c5671c7..974ea6e 100644 --- a/backprop/backward_test.cc +++ b/backprop/backward_test.cc @@ -33,6 +33,7 @@ #include "backprop/test_util.h" #include "gemma/configs.h" #include "ops/ops.h" +#include "util/threading.h" #include "hwy/base.h" #include "hwy/contrib/thread_pool/thread_pool.h" @@ -58,7 +59,9 @@ void TestMatMulVJP() { static const size_t kRows = 8; static const size_t kCols = 64; static const size_t kTokens = 5; - hwy::ThreadPool pool(8); + gcpp::NestedPools pools(1, /*pin=*/Tristate::kFalse, BoundedSlice(0, 1), + BoundedSlice(0, 8)); + Allocator::Init(pools.Topology()); std::mt19937 gen(42); MatStorageT weights("weights", kRows, kCols); MatStorageT x("x", kTokens, kCols); @@ -85,7 +88,7 @@ void TestMatMulVJP() { grad.ZeroInit(); MatMulVJP(weights.data(), x.data(), dy.data(), kCols, kRows, kTokens, - grad.data(), dx.data(), pool); + grad.data(), dx.data(), pools.Pool()); TestGradient(dx, c_x, func, 5e-5f, 5e-5f, __LINE__); TestGradient(grad, c_weights, func, 5e-5f, 5e-5f, __LINE__); @@ -102,7 +105,9 @@ void TestMultiHeadMatMulVJP() { static const size_t kCols = 16; static const size_t kHeads = 4; static const size_t kTokens = 3; - hwy::ThreadPool pool(8); + gcpp::NestedPools pools(1, /*pin=*/Tristate::kFalse, BoundedSlice(0, 1), + BoundedSlice(0, 8)); + Allocator::Init(pools.Topology()); std::mt19937 gen(42); MatStorageT weights("weights", kRows, kCols * kHeads); MatStorageT x("x", kTokens, kCols * kHeads); @@ -130,7 +135,7 @@ void 
TestMultiHeadMatMulVJP() { grad.ZeroInit(); MultiHeadMatMulVJP(weights.data(), x.data(), dy.data(), kHeads, kCols, - kRows, kTokens, grad.data(), dx.data(), pool); + kRows, kTokens, grad.data(), dx.data(), pools.Pool()); TestGradient(dx, c_x, func, 5e-5f, 5e-5f, __LINE__); TestGradient(grad, c_weights, func, 5e-5f, 5e-5f, __LINE__); @@ -145,7 +150,9 @@ void TestMultiHeadMatMulVJP() { void TestRMSNormVJP() { static const size_t K = 2; static const size_t N = 64; - hwy::ThreadPool pool(8); + gcpp::NestedPools pools(1, /*pin=*/Tristate::kFalse, BoundedSlice(0, 1), + BoundedSlice(0, 8)); + Allocator::Init(pools.Topology()); std::mt19937 gen(42); MatStorageT weights("weights", N, 1); MatStorageT x("x", K, N); @@ -172,7 +179,7 @@ void TestRMSNormVJP() { grad.ZeroInit(); RMSNormVJP(weights.data(), x.data(), dy.data(), N, K, grad.data(), - dx.data(), pool); + dx.data(), pools.Pool()); TestGradient(dx, c_x, func, 5e-5f, 5e-5f, __LINE__); TestGradient(grad, c_weights, func, 5e-5f, 5e-5f, __LINE__); @@ -209,7 +216,9 @@ static ModelConfig TestConfig() { void TestEndToEnd() { std::mt19937 gen(42); - hwy::ThreadPool pool(0); + gcpp::NestedPools pools(1, /*pin=*/Tristate::kFalse, BoundedSlice(0, 1), + BoundedSlice(0, 1)); + Allocator::Init(pools.Topology()); ModelConfig config = TestConfig(); WeightsWrapper weights(config); WeightsWrapper grad(config); @@ -234,13 +243,13 @@ void TestEndToEnd() { float loss1 = CrossEntropyLossForwardPass( prompt.tokens, prompt.context_size, weights.get(), forward1, - inv_timescale, pool); + inv_timescale, pools.Pool()); EXPECT_NEAR(loss1, loss0, std::abs(loss0) * 2e-5); grad.ZeroInit(); CrossEntropyLossBackwardPassInl(prompt, weights.get(), forward1, grad.get(), - backward, inv_timescale, pool); + backward, inv_timescale, pools.Pool()); Complexify(weights.get(), c_weights.get()); auto func = [&]() { diff --git a/gemma/gemma.h b/gemma/gemma.h index fef4b7a..d1a33a6 100644 --- a/gemma/gemma.h +++ b/gemma/gemma.h @@ -252,6 +252,10 @@ class Gemma { void GenerateImageTokens(const RuntimeConfig& runtime_config, const Image& image, ImageTokens& image_tokens); + void SetMatMulVerbosity(int verbosity) { + if (verbosity >= 2) env_.print_best = true; + } + private: MatMulEnv env_; diff --git a/gemma/run.cc b/gemma/run.cc index 712bdcb..2052100 100644 --- a/gemma/run.cc +++ b/gemma/run.cc @@ -225,7 +225,7 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app, } if (end_of_turn_seen && abs_pos > 0) { // If we have seen an end_of_turn token, we need to rewind abs_pos by one - // more, because we will pre-pend it again to the prompt in + // more, because we will prepend it again to the prompt in // WrapAndTokenize. abs_pos--; } @@ -236,14 +236,13 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app, void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) { PROFILER_ZONE("Run.misc"); - // TODO: remove once MatMul is updated. - app.max_packages = 1; // Note that num_threads is an upper bound; we also limit to the number of // detected and enabled cores. 
NestedPools pools = CreatePools(app); Allocator::Init(pools.Topology()); Gemma model = CreateGemma(loader, pools); + model.SetMatMulVerbosity(app.verbosity); KVCache kv_cache = KVCache::Create(model.GetModelConfig(), inference.prefill_tbatch_size); diff --git a/ops/bench_matmul.cc b/ops/bench_matmul.cc index fa38c50..4a4d489 100644 --- a/ops/bench_matmul.cc +++ b/ops/bench_matmul.cc @@ -117,17 +117,18 @@ MatStoragePtr GenerateTransposedMat(const Extents2D extents, } void PrintSpeed(const Extents2D& A_extents, const Extents2D& B_extents, - std::vector& times) { + std::vector& times, MMPerKey* per_key) { std::sort(times.begin(), times.end()); // bench_dnn reports the best and average, but the median seems more // consistent and resistant to outliers. const double elapsed = times[times.size() / 2]; - const double ratio = elapsed / (times[0] + 1E-6); // vs best, avoid / 0 + const double vs_best = elapsed / (times[0] + 1E-6); // avoid / 0 const size_t num_b = B_extents.Area(); - // FMA counts as two FLOP. - fprintf(stderr, "%.1f\t(med %.3f ms = %0.2fx min)\n", - 2 * 1E-9 * A_extents.rows * num_b / elapsed, elapsed * 1E3, ratio); + const double flops = 2 * A_extents.rows * num_b / elapsed; // FMA = 2 ops + + fprintf(stderr, "\t%.1f GFLOPS %.3f ms %0.2fx\n", flops * 1E-9, elapsed * 1E3, + vs_best); } // Generates inputs and prints observed throughput of MatMul. @@ -135,15 +136,18 @@ void PrintSpeed(const Extents2D& A_extents, const Extents2D& B_extents, template void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) { hwy::ThreadPool& pool = env.parallel.Pools().Pool(0); - fprintf(stderr, "\nBenchMatMul %lu, %lu, %lu, add=%d, MatTA=%s, MatTB=%s\n", - M, K, N, add, TypeName(), TypeName()); + if (env.print_config || env.print_measurement) { + fprintf(stderr, "\n"); + } + fprintf(stderr, "BenchMatMul %zu, %zu, %zu, add=%d, TA=%s, TB=%s\n", M, K, N, + add, TypeName(), TypeName()); const Extents2D A_extents(M, K); const Extents2D B_extents(N, K); // already transposed const Extents2D C_extents(M, N); - RowVectorBatch c_slow_batch(C_extents); - RowVectorBatch c_batch(C_extents); + RowVectorBatch c_slow_batch = AllocateAlignedRows(C_extents); + RowVectorBatch c_batch = AllocateAlignedRows(C_extents); std::unique_ptr> add_storage; if (add) { @@ -161,27 +165,40 @@ void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) { const float* add_row = add ? add_storage->data_scale1() : nullptr; const RowPtrF C = RowPtrFromBatch(c_batch); - constexpr size_t kSamples = 20; + // Fewer reps for large batch sizes, which take longer. + const size_t num_samples = M < 32 ? 20 : 12; std::vector times; - times.reserve(kSamples); + times.reserve(num_samples); + + // Ensure usage conditions are set before autotuning. Both binding and + // spinning may materially affect the choice of config. No harm in calling + // BindB/C if there is a single package: they will be a no-op. 
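// (Here "binding" presumably means NUMA placement of the B and C rows on the
// package that will use them, and "spinning" means busy-waiting workers; both
// change memory latency and wake-up cost, so the autotuner should observe the
// same conditions as the timed runs below.)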
+ BindB(B_extents.rows, B, env.parallel); + BindC(A_extents.rows, C, env.parallel); Tristate use_spinning = Tristate::kDefault; env.parallel.Pools().MaybeStartSpinning(use_spinning); + // env.print_config = true; + // env.print_measurement = true; + env.print_best = true; + double keep = 0.0; + MMPerKey* per_key; // Until enough samples collected *after* autotuning finished: - while (times.size() < kSamples) { + while (times.size() < num_samples) { const double t0 = hwy::platform::Now(); - MatMul(A, B, add_row, env, C); + per_key = MatMul(A, B, add_row, env, C); const double t1 = hwy::platform::Now(); double elapsed = t1 - t0; keep += C.Row(0)[hwy::Unpredictable1()]; - times.push_back(elapsed); + // Only record times after autotuning finished. + if (per_key->autotune.Best()) times.push_back(elapsed); } hwy::PreventElision(keep); env.parallel.Pools().MaybeStopSpinning(use_spinning); - PrintSpeed(A_extents, B_extents, times); + PrintSpeed(A_extents, B_extents, times, per_key); } using F32 = float; @@ -189,29 +206,31 @@ using SFP = SfpStream; void BenchAllMatMul() { if (first_target == 0) first_target = HWY_TARGET; - if (HWY_TARGET != first_target) return; - - for (size_t max_packages : {/*1,*/ 2}) { - const size_t max_threads = 0; // no limit - NestedPools pools(max_threads, Tristate::kDefault, - BoundedSlice(0, max_packages)); -#if GEMMA_DISABLE_TOPOLOGY - if (max_packages == 2) break; // we only have one package -#else - // If less than the limit, we have already tested all num_packages. - if (pools.Topology().FullTopology().packages.size() < max_packages) break; -#endif - fprintf(stderr, "BenchAllMatMul %zu: %s %s\n", max_packages, - pools.TopologyString(), pools.PinString()); + // Disable the best-target-only limitation. + // if (HWY_TARGET != first_target) return; - Allocator::Init(pools.Topology()); - MatMulEnv env(pools); + // Skip EMU128 (10x slower than SSE4 for SFP) and older x86. + if (HWY_TARGET == HWY_EMU128 || HWY_TARGET == HWY_SSSE3 || + HWY_TARGET == HWY_SSE2) { + return; + } - for (size_t batch_size : {1, 4, 128, 512}) { - constexpr bool kAdd = false; - BenchMatMul(batch_size, 24576, 3072, kAdd, env); - BenchMatMul(batch_size, 3072, 24576, kAdd, env); - } + const size_t max_threads = 0; // no limit + const BoundedSlice package_slice(0, 1); // all packages/sockets + const BoundedSlice cluster_slice(0, 1); // all clusters/CCX + const BoundedSlice lp_slice(0, 1); // default to all cores (per package). 
+ NestedPools pools(max_threads, Tristate::kDefault, package_slice, + cluster_slice, lp_slice); + fprintf(stderr, "BenchAllMatMul %s %s\n", pools.TopologyString(), + pools.PinString()); + + Allocator::Init(pools.Topology(), /*enable_bind=*/true); + MatMulEnv env(pools); + + for (size_t batch_size : {1, 4, 128, 512}) { + constexpr bool kAdd = false; + BenchMatMul(batch_size, 24576, 3072, kAdd, env); + BenchMatMul(batch_size, 3072, 24576, kAdd, env); } PROFILER_PRINT_RESULTS(); diff --git a/ops/dot_test.cc b/ops/dot_test.cc index 770ad5a..6533edb 100644 --- a/ops/dot_test.cc +++ b/ops/dot_test.cc @@ -28,6 +28,7 @@ #include "compression/shared.h" #include "util/allocator.h" +#include "util/app.h" #include "util/test_util.h" #include "util/threading.h" #include "hwy/base.h" @@ -999,6 +1000,8 @@ struct TestShortDotsT { const size_t N = hn::Lanes(d); const hn::ScalableTag df; // for CallDot + NestedPools pools = CreatePools(AppArgs()); + Allocator::Init(pools.Topology()); CompressWorkingSet work; std::mt19937 rng; rng.seed(12345); diff --git a/ops/matmul-inl.h b/ops/matmul-inl.h index 18d4e6b..9de9357 100644 --- a/ops/matmul-inl.h +++ b/ops/matmul-inl.h @@ -16,9 +16,18 @@ #include #include +#include +#include + +#include "compression/shared.h" #include "ops/matmul.h" // IWYU pragma: export #include "util/allocator.h" #include "util/basics.h" +#include "util/threading.h" +#include "hwy/base.h" +#include "hwy/contrib/thread_pool/thread_pool.h" +#include "hwy/profiler.h" +#include "hwy/timer.h" // Include guard for (potentially) SIMD code. #if defined(THIRD_PARTY_GEMMA_CPP_MATMUL_TOGGLE) == defined(HWY_TARGET_TOGGLE) @@ -31,477 +40,1312 @@ #include "hwy/highway.h" // After highway.h #include "compression/compress-inl.h" -#include "ops/ops-inl.h" HWY_BEFORE_NAMESPACE(); namespace gcpp { namespace HWY_NAMESPACE { namespace hn = hwy::HWY_NAMESPACE; -// Loads two vectors at a time with element type hn::TFromD from a row of -// transposed B. Called in a loop over col_ab. No bounds checking because -// `kRow` is from B columns, which we checked is a multiple of `kRegCols`. -template -class BRow { - static_assert(kRow < kRegRows); // which unrolled instance we are +// Like hn::PromoteOddTo, but uses assembly to avoid an extra vector register. +template > +static hn::VFromD FastPromoteOddTo(DF df, hn::VFromD vbf) { + // Promoting odd means clearing the lower 16 bits. Doing this via AND + // requires a second input vector, which we prefer to avoid due to high + // register pressure. Unfortunately `hn::IfThenElseZero` and + // `IfThenZeroElse` are 'optimized' back to AND, hence resort to assembly. + // Note that SVE also has separate mask registers, but it anyway uses the + // native BF16 dot product code path. +#if HWY_TARGET < HWY_AVX2 + const hn::Repartition du16; + const auto odd = static_cast<__mmask32>(0xAAAAAAAAu); // 10..10 (32 lanes) + // In-out because this is called after PromoteEvenTo, when we can clobber + // the original bf16 input. + auto u16 = hn::BitCast(du16, vbf).raw; + // Odd u16 lanes are set to the input and even lanes are zero. 
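// A roughly equivalent portable sketch of the masked move below, shown only
// for illustration; it needs the extra constant register avoided here:
//   const hn::Repartition<uint32_t, DF> du32;
//   return hn::BitCast(df, hn::And(hn::BitCast(du32, vbf),
//                                  hn::Set(du32, 0xFFFF0000u)));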
+ asm("vmovdqu16 %[U16], %[U16]%{%[ODD]%}%{z%};" + : [U16] "+v"(u16) // AVX-512 reg + : [ODD] "Yk"(odd)); // mask reg except k0 (not writable) + return hn::BitCast(df, hn::VFromD{u16}); +#else + return hn::PromoteOddTo(df, vbf); +#endif +} + +// Tag classes, passed to `MMKernel::A2C0` to choose between writing one +// (all-K) result to C via `MMStoreHorizontalSumsIntoC`, or writing the +// first kc result to partial, or accumulating the next kc result into partial +// via `MMAddHorizontalSumsIntoPartial`. +struct MMSetC {}; +struct MMSetPartial {}; +struct MMAddPartial {}; +// Stores horizontal sums of up to 16 vectors via transpose. +template +class MMStoreHorizontalSumsIntoC { public: - BRow(const ConstMat& B, size_t row_b) - : B_(MakeSpan(B.ptr, B.ofs + B.Extents().Area())), - B_ofs_(B.Row(HWY_MIN(row_b + kRow, B.Extents().rows - 1))) {} + static_assert(kNR == 4); // for `StoreInterleaved4` - template > - HWY_INLINE void Load2(DR d, size_t col_ab, VR& b0, VR& b1) const { - Decompress2(d, B_, B_ofs_ + col_ab, b0, b1); + // Computes horizontal sums of `kRowsAC x kNR` vectors and stores into + // `C` starting at `(row_c, col_c)`. + // + // `Crc` are the 16 combinations of an A row vector indexed by `r`, times a + // transposed B row vector indexed by `c`. Their elements are thus a subset + // of the terms of the dot product constituting the final `C[r, c]` result. + // Thus we compute the horizontal sums of each `Crc`. The elements may be + // permuted because we multiply bf16 via `ReorderWidenMulAccumulate`, but + // this does not change their horizontal sum. + template > + HWY_INLINE void operator()(DF df, // + VF C00, VF C01, VF C02, VF C03, // + VF C10, VF C11, VF C12, VF C13, // + VF C20, VF C21, VF C22, VF C23, // + VF C30, VF C31, VF C32, VF C33, // + const size_t row_c, const size_t col_c, + const MMArgs& args) const { + float buf[16 * hn::MaxLanes(df)]; + const size_t N = hn::Lanes(df); + // Horizontal reductions (`ReduceSum`) are rather expensive, entailing + // log(N) operations for vectors of length N. Because `kNR` == 4, we + // instead use `StoreInterleaved4` for a vector length-agnostic + // 'transpose': `buf[0, 4 * N)` holds `C00[0], C01[0], C02[0], C03[0], + // C00[1], C01[1], C02[1], C03[1] .. C00[N-1], C01[N-1], C02[N-1], + // C03[N-1]`. + MaybeStoreInterleaved4<0>(df, N, C00, C01, C02, C03, buf); + MaybeStoreInterleaved4<1>(df, N, C10, C11, C12, C13, buf); + MaybeStoreInterleaved4<2>(df, N, C20, C21, C22, C23, buf); + MaybeStoreInterleaved4<3>(df, N, C30, C31, C32, C33, buf); + // Adding N consecutive V4 yields horizontal sums of Cr0, Cr1, Cr2, Cr3 in + // the elements of one V4. We have four independent rows `r`, hence the + // code is effectively unrolled, which increases throughput. + const hn::CappedTag d4; + using V4 = hn::Vec; + // Store to four elements per row of `partial`. + // No loop is required because vectors are at least 4*32 bits. 
+ V4 sum0 = MaybeLoad<0>(d4, N, buf); + V4 sum1 = MaybeLoad<1>(d4, N, buf); + V4 sum2 = MaybeLoad<2>(d4, N, buf); + V4 sum3 = MaybeLoad<3>(d4, N, buf); + + for (size_t lane = 1; lane < N; ++lane) { + sum0 = MaybeAdd<0>(d4, N, sum0, buf + kNR * lane); + sum1 = MaybeAdd<1>(d4, N, sum1, buf + kNR * lane); + sum2 = MaybeAdd<2>(d4, N, sum2, buf + kNR * lane); + sum3 = MaybeAdd<3>(d4, N, sum3, buf + kNR * lane); + } + const V4 vscale = hn::Set(d4, args.scale); + V4 vadd = hn::Zero(d4); + if constexpr (kAdd) { + vadd = hn::Load(d4, args.add + col_c); + } + MaybeScaleAndStore<0>(d4, sum0, vscale, vadd, args.C, row_c, col_c); + MaybeScaleAndStore<1>(d4, sum1, vscale, vadd, args.C, row_c, col_c); + MaybeScaleAndStore<2>(d4, sum2, vscale, vadd, args.C, row_c, col_c); + MaybeScaleAndStore<3>(d4, sum3, vscale, vadd, args.C, row_c, col_c); } private: - PackedSpan B_; - const size_t B_ofs_; -}; + // These helper functions hoist if() out of the main code below. They have + // no effect if kRow >= kRowsAC. + template > + static HWY_INLINE void MaybeStoreInterleaved4(DD dd, size_t N, VD Cr0, VD Cr1, + VD Cr2, VD Cr3, + float* HWY_RESTRICT buf) { + if constexpr (kRow < kRowsAC) { + hn::StoreInterleaved4(Cr0, Cr1, Cr2, Cr3, dd, buf + 4 * kRow * N); + } + } -// Loads *two* row vectors from A via `Decompress2`, widens to f32, multiplies -// element-wise with `kRegRows` x 2 row vectors from transposed B, and adds -// them to `kRegRows` x `kRegCols` C vectors. The lanes of `C[r,c]` are thus a -// subset of the terms of the dot products that make up the MatMul result at -// `r,c`. No-op for the bottom-most rows whose `kRow >= kNumRows`. -// -// This approach is atypical because it requires a horizontal sum, for which we -// introduce a fast and new(?) vector-length agnostic 'transpose', see -// `AddHorizontalSums`. Most MatMul instead broadcast one element from A and -// multiply with one element from N columns in B to obtain N columns of C. -// This is a poor fit for our setting: -// - `Decompress2` decompresses two vectors at a time; -// - B is column-major, so unit-stride SIMD loads return a column, not values -// from different columns, i.e. a row. -// - `ReorderWidenMulAccumulate` is important for bf16 performance, but its -// pairwise adds would add together unrelated terms. -// The first two could be fixed in a packing stage, which is not implemented -// yet, and might not be necessary otherwise. The third seems a fundamental -// mismatch. However, pairwise adds are fine in our setting because C lanes are -// the terms of a single dot product, which can be reordered or pre-reduced. -template -class ALoadAccumulate { - public: - static_assert(kRow < kRegRows); // which unrolled instance we are - // `First` and `Next` handle a single row of A, so the horizontal sums of - // their `C0..3` are the (partial) dot products for 4 consecutive values in - // one row of C. - static_assert(kRegCols == 4); - - ALoadAccumulate(const ConstMat& A, size_t row_ac) - : A_(MakeSpan(A.ptr, A.ofs + A.Extents().Area())), - A_ofs_(A.Row(HWY_MIN(row_ac + kRow, A.Extents().rows - 1))) {} - - // First iteration, col_ab = 0: initialize C0..3 instead of updating them. 
- template , HWY_IF_F32_D(DM)> - HWY_INLINE void First(DM dm, // - const VM b00, const VM b01, const VM b10, const VM b11, - const VM b20, const VM b21, const VM b30, const VM b31, - VM& C0, VM& C1, VM& C2, VM& C3) const { - static_assert(kNumRows <= kRegRows); // How many rows actually present - if constexpr (kRow < kNumRows) { - VM a0, a1; - Decompress2(dm, A_, A_ofs_, a0, a1); - - static_assert(kRegCols == 4); - C0 = hn::Mul(a0, b00); - C1 = hn::Mul(a0, b10); - C2 = hn::Mul(a0, b20); - C3 = hn::Mul(a0, b30); - C0 = hn::MulAdd(a1, b01, C0); - C1 = hn::MulAdd(a1, b11, C1); - C2 = hn::MulAdd(a1, b21, C2); - C3 = hn::MulAdd(a1, b31, C3); - } - } - - // Same as above, only called if MulT == BF16. - template , - HWY_IF_BF16_D(DM), class DF = hn::Repartition, - class VF = hn::Vec> - HWY_INLINE void First(DM dm, // - const VM b00, const VM b01, const VM b10, const VM b11, - const VM b20, const VM b21, const VM b30, const VM b31, - VF& C0, VF& C1, VF& C2, VF& C3) const { - static_assert(kNumRows <= kRegRows); // How many rows actually present - if constexpr (kRow < kNumRows) { - VM a0, a1; - Decompress2(dm, A_, A_ofs_, a0, a1); - - const DF df; - - static_assert(kRegCols == 4); - C0 = hn::WidenMulPairwiseAdd(df, a0, b00); - C1 = hn::WidenMulPairwiseAdd(df, a0, b10); - C2 = hn::WidenMulPairwiseAdd(df, a0, b20); - C3 = hn::WidenMulPairwiseAdd(df, a0, b30); - if constexpr (HWY_NATIVE_DOT_BF16) { - // Native ReorderWidenMulAccumulate adds to C0..3 for free. - VF unused_sum1 = hn::Zero(df); - C0 = hn::ReorderWidenMulAccumulate(df, a1, b01, C0, unused_sum1); - C1 = hn::ReorderWidenMulAccumulate(df, a1, b11, C1, unused_sum1); - C2 = hn::ReorderWidenMulAccumulate(df, a1, b21, C2, unused_sum1); - C3 = hn::ReorderWidenMulAccumulate(df, a1, b31, C3, unused_sum1); - // Ensure sum1 was indeed unused. - HWY_DASSERT(hn::AllTrue(df, hn::Eq(unused_sum1, hn::Zero(df)))); - } else { - C0 = hn::Add(C0, hn::WidenMulPairwiseAdd(df, a1, b01)); - C1 = hn::Add(C1, hn::WidenMulPairwiseAdd(df, a1, b11)); - C2 = hn::Add(C2, hn::WidenMulPairwiseAdd(df, a1, b21)); - C3 = hn::Add(C3, hn::WidenMulPairwiseAdd(df, a1, b31)); - } + // Note: N is the number of lanes in the StoreInterleaved4 vectors, not V4. + template > + static HWY_INLINE V4 MaybeLoad(D4 d4, size_t N, + const float* HWY_RESTRICT buf) { + if constexpr (kRow < kRowsAC) { + return hn::Load(d4, buf + 4 * kRow * N); + } else { + return hn::Zero(d4); } } - // Non-first iteration: accumulate into C0..3. - template , HWY_IF_F32_D(DM)> - HWY_INLINE void Next(DM dm, size_t col_ab, const VM b00, const VM b01, - const VM b10, const VM b11, const VM b20, const VM b21, - const VM b30, const VM b31, VM& C0, VM& C1, VM& C2, - VM& C3) const { - static_assert(kNumRows <= kRegRows); // How many rows actually present - HWY_DASSERT(col_ab >= 2 * hn::Lanes(dm)); // Should not be first iteration. - if constexpr (kRow < kNumRows) { - VM a0, a1; - Decompress2(dm, A_, A_ofs_ + col_ab, a0, a1); - - static_assert(kRegCols == 4); - C0 = hn::MulAdd(a0, b00, C0); - C1 = hn::MulAdd(a0, b10, C1); - C2 = hn::MulAdd(a0, b20, C2); - C3 = hn::MulAdd(a0, b30, C3); - C0 = hn::MulAdd(a1, b01, C0); - C1 = hn::MulAdd(a1, b11, C1); - C2 = hn::MulAdd(a1, b21, C2); - C3 = hn::MulAdd(a1, b31, C3); - } - } - - // Same as above, only called if MulT == BF16. 
- template , - HWY_IF_BF16_D(DM), class DF = hn::Repartition, - class VF = hn::Vec> - HWY_INLINE void Next(DM dm, size_t col_ab, const VM b00, const VM b01, - const VM b10, const VM b11, const VM b20, const VM b21, - const VM b30, const VM b31, VF& C0, VF& C1, VF& C2, - VF& C3) const { - static_assert(kNumRows <= kRegRows); // How many rows actually present - HWY_DASSERT(col_ab >= 2 * hn::Lanes(dm)); // Should not be first iteration. - if constexpr (kRow < kNumRows) { - VM a0, a1; - Decompress2(dm, A_, A_ofs_ + col_ab, a0, a1); - - const DF df; - - static_assert(kRegCols == 4); - if constexpr (HWY_NATIVE_DOT_BF16) { - // Native ReorderWidenMulAccumulate adds to C0..3 for free. - VF unused_sum1 = hn::Zero(df); - C0 = hn::ReorderWidenMulAccumulate(df, a0, b00, C0, unused_sum1); - C1 = hn::ReorderWidenMulAccumulate(df, a0, b10, C1, unused_sum1); - C2 = hn::ReorderWidenMulAccumulate(df, a0, b20, C2, unused_sum1); - C3 = hn::ReorderWidenMulAccumulate(df, a0, b30, C3, unused_sum1); - C0 = hn::ReorderWidenMulAccumulate(df, a1, b01, C0, unused_sum1); - C1 = hn::ReorderWidenMulAccumulate(df, a1, b11, C1, unused_sum1); - C2 = hn::ReorderWidenMulAccumulate(df, a1, b21, C2, unused_sum1); - C3 = hn::ReorderWidenMulAccumulate(df, a1, b31, C3, unused_sum1); - // Ensure sum1 was indeed unused. - HWY_DASSERT(hn::AllTrue(df, hn::Eq(unused_sum1, hn::Zero(df)))); - } else { - C0 = hn::Add(C0, hn::WidenMulPairwiseAdd(df, a0, b00)); - C1 = hn::Add(C1, hn::WidenMulPairwiseAdd(df, a0, b10)); - C2 = hn::Add(C2, hn::WidenMulPairwiseAdd(df, a0, b20)); - C3 = hn::Add(C3, hn::WidenMulPairwiseAdd(df, a0, b30)); - C0 = hn::Add(C0, hn::WidenMulPairwiseAdd(df, a1, b01)); - C1 = hn::Add(C1, hn::WidenMulPairwiseAdd(df, a1, b11)); - C2 = hn::Add(C2, hn::WidenMulPairwiseAdd(df, a1, b21)); - C3 = hn::Add(C3, hn::WidenMulPairwiseAdd(df, a1, b31)); + template > + static HWY_INLINE V4 MaybeAdd(D4 d4, size_t N, V4 sum, + const float* HWY_RESTRICT buf) { + if constexpr (kRow < kRowsAC) { + return hn::Add(sum, hn::Load(d4, buf + 4 * kRow * N)); + } else { + return sum; + } + } + + template > + static HWY_INLINE void MaybeScaleAndStore(D4 d4, V4 sum, V4 vscale, V4 vadd, + const RowPtrF& C, + const size_t row_c, + const size_t col_c) { + if constexpr (kRow < kRowsAC) { + float* HWY_RESTRICT pos = C.Row(row_c + kRow) + col_c; + hn::Store(hn::MulAdd(sum, vscale, vadd), d4, pos); + } + } +}; // MMStoreHorizontalSumsIntoC + +// Accumulates horizontal sums of up to 16 vectors via transpose. +template +class MMAddHorizontalSumsIntoPartial { + public: + static_assert(kNR == 4); // for `StoreInterleaved4` + + // Computes horizontal sums of `kRowsAC x kNR` vectors and accumulates + // into `partial` starting at `(row_c, col_c)`. + // + // `Crc` are the 16 combinations of an A row vector indexed by `r`, times a + // transposed B row vector indexed by `c`. Their elements are thus a subset + // of the terms of the dot product constituting the final `C[r, c]` result. + // Thus we compute the horizontal sums of each `Crc`. The elements may be + // permuted because we multiply bf16 via `ReorderWidenMulAccumulate`, but + // this does not change their horizontal sum. + template > + HWY_INLINE void operator()(DF df, // + VF F00, VF F01, VF F02, VF F03, // + VF F10, VF F11, VF F12, VF F13, // + VF F20, VF F21, VF F22, VF F23, // + VF F30, VF F31, VF F32, VF F33, // + const size_t row_c, const size_t col_c, + const RowPtrD& partial) const { + // We accumulate in 64-bit to avoid loss of precision. 
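// (Each call folds one kc range's partial dot products into `partial`, so the
// accumulator sees many more terms than a single kernel invocation; f64 keeps
// the rounding error of the final f32 result small.)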
+ static_assert(HWY_HAVE_FLOAT64, "Disable Armv7 NEON: we require fp64"); + + const hn::Repartition dd; + double buf[16 * hn::MaxLanes(dd)]; + using VD = hn::Vec; + const size_t ND = hn::Lanes(dd); + VD C00 = SumOfPromotedPairs(dd, F00); + VD C01 = SumOfPromotedPairs(dd, F01); + VD C02 = SumOfPromotedPairs(dd, F02); + VD C03 = SumOfPromotedPairs(dd, F03); + VD C10 = SumOfPromotedPairs(dd, F10); + VD C11 = SumOfPromotedPairs(dd, F11); + VD C12 = SumOfPromotedPairs(dd, F12); + VD C13 = SumOfPromotedPairs(dd, F13); + VD C20 = SumOfPromotedPairs(dd, F20); + VD C21 = SumOfPromotedPairs(dd, F21); + VD C22 = SumOfPromotedPairs(dd, F22); + VD C23 = SumOfPromotedPairs(dd, F23); + VD C30 = SumOfPromotedPairs(dd, F30); + VD C31 = SumOfPromotedPairs(dd, F31); + VD C32 = SumOfPromotedPairs(dd, F32); + VD C33 = SumOfPromotedPairs(dd, F33); + + // Horizontal reductions (`ReduceSum`) are rather expensive, entailing + // log(N) operations for vectors of length N. Because `kNR` == 4, we + // instead use `StoreInterleaved4` for a vector length-agnostic + // 'transpose': `buf[0, 4 * N)` holds `C00[0], C01[0], C02[0], C03[0], + // C00[1], C01[1], C02[1], C03[1] .. C00[N-1], C01[N-1], C02[N-1], + // C03[N-1]`. + MaybeStoreInterleaved4<0>(dd, ND, C00, C01, C02, C03, buf); + MaybeStoreInterleaved4<1>(dd, ND, C10, C11, C12, C13, buf); + MaybeStoreInterleaved4<2>(dd, ND, C20, C21, C22, C23, buf); + MaybeStoreInterleaved4<3>(dd, ND, C30, C31, C32, C33, buf); + // Adding N consecutive V4 yields horizontal sums of Cr0, Cr1, Cr2, Cr3 in + // the elements of one V4. We have four independent rows `r`, hence the + // code is effectively unrolled, which increases throughput. + const hn::CappedTag d4; + using V4 = hn::Vec; + // Store to four elements per row of `partial`. + // Loop is required because vectors may be smaller than 4*64 bits. + for (size_t c = 0; c < kNR; c += hn::Lanes(d4)) { + V4 sum0 = MaybeLoad<0>(d4, ND, buf + c); + V4 sum1 = MaybeLoad<1>(d4, ND, buf + c); + V4 sum2 = MaybeLoad<2>(d4, ND, buf + c); + V4 sum3 = MaybeLoad<3>(d4, ND, buf + c); + + for (size_t lane = 1; lane < ND; ++lane) { + sum0 = MaybeAdd<0>(d4, ND, sum0, buf + c + kNR * lane); + sum1 = MaybeAdd<1>(d4, ND, sum1, buf + c + kNR * lane); + sum2 = MaybeAdd<2>(d4, ND, sum2, buf + c + kNR * lane); + sum3 = MaybeAdd<3>(d4, ND, sum3, buf + c + kNR * lane); } + MaybeAddStore<0>(d4, sum0, partial, row_c, col_c + c); + MaybeAddStore<1>(d4, sum1, partial, row_c, col_c + c); + MaybeAddStore<2>(d4, sum2, partial, row_c, col_c + c); + MaybeAddStore<3>(d4, sum3, partial, row_c, col_c + c); } } private: - PackedSpan A_; - const size_t A_ofs_; -}; // ALoadAccumulate - -// Sets a `kRegRows` x `kRegCols` tile of C to `add[add_ofs + c]` if kAdd, -// otherwise 0. -// `add` has no scale and is a row vector with A.cols entries if `kAdd`, -// otherwise nullptr. In the latter case, adding `add_ofs` to it would be UB, -// hence we pass it as a separate argument. -template -HWY_INLINE void InitC(const float* HWY_RESTRICT add, size_t add_ofs, - float* HWY_RESTRICT pos_c, size_t stride_c) { - const hn::FixedTag d4; - for (size_t r = 0; r < HWY_MIN(kNumRows, kRegRows); ++r) { - if constexpr (kAdd) { - hn::StoreU(hn::LoadU(d4, add + add_ofs), d4, pos_c + r * stride_c); - } else { - hn::StoreU(hn::Zero(d4), d4, pos_c + r * stride_c); - } + // Converts lanes to double and adds pairs of them to obtain a vector with the + // same horizontal sum, but element type double. 
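// Example for a 4-lane f32 vector {a, b, c, d}: `PromoteLowerTo` -> {a, b},
// `PromoteUpperTo` -> {c, d}, and their sum {a+c, b+d} still reduces to
// a+b+c+d.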
+ template , + class DF = hn::Repartition, class VF = hn::Vec> + static HWY_INLINE VD SumOfPromotedPairs(DD dd, VF f) { + // TODO: SVE could PromoteEvenTo. + const VD d0 = hn::PromoteLowerTo(dd, f); + const VD d1 = hn::PromoteUpperTo(dd, f); + return hn::Add(d0, d1); } -} -// Accumulates into a tile of C. -template -class AddHorizontalSums { - // These helper functions hoist if() out of the main code below. They have no - // effect if kRow >= kNumRows. - template > - static void MaybeStoreInterleaved4(DF df, size_t N, VF Cr0, VF Cr1, VF Cr2, - VF Cr3, float* HWY_RESTRICT buf) { - if constexpr (kRow < kNumRows) { - hn::StoreInterleaved4(Cr0, Cr1, Cr2, Cr3, df, buf + 4 * kRow * N); + // These helper functions hoist if() out of the main code below. They have + // no effect if kRow >= kRowsAC. + template > + static HWY_INLINE void MaybeStoreInterleaved4(DD dd, size_t N, VD Cr0, VD Cr1, + VD Cr2, VD Cr3, + double* HWY_RESTRICT buf) { + if constexpr (kRow < kRowsAC) { + hn::StoreInterleaved4(Cr0, Cr1, Cr2, Cr3, dd, buf + 4 * kRow * N); } } // Note: N is the number of lanes in the StoreInterleaved4 vectors, not V4. template > - static V4 MaybeLoad(D4 df, size_t N, const float* HWY_RESTRICT buf) { - if constexpr (kRow < kNumRows) { - return hn::Load(df, buf + 4 * kRow * N); + static HWY_INLINE V4 MaybeLoad(D4 d4, size_t N, + const double* HWY_RESTRICT buf) { + if constexpr (kRow < kRowsAC) { + return hn::Load(d4, buf + 4 * kRow * N); } else { - return hn::Zero(df); + return hn::Zero(d4); } } template > - static V4 MaybeAdd(D4 df, size_t N, V4 sum, const float* HWY_RESTRICT buf) { - if constexpr (kRow < kNumRows) { - return hn::Add(sum, hn::Load(df, buf + 4 * kRow * N)); + static HWY_INLINE V4 MaybeAdd(D4 d4, size_t N, V4 sum, + const double* HWY_RESTRICT buf) { + if constexpr (kRow < kRowsAC) { + return hn::Add(sum, hn::Load(d4, buf + 4 * kRow * N)); } else { return sum; } } template > - static void MaybeMulAdd(D4 df, V4 sum, V4 scale, float* HWY_RESTRICT tile_c, - const size_t stride_c) { - if constexpr (kRow < kNumRows) { - const V4 prev_c = hn::LoadU(df, tile_c + kRow * stride_c); - hn::StoreU(hn::MulAdd(sum, scale, prev_c), df, tile_c + kRow * stride_c); + static HWY_INLINE void MaybeAddStore(D4 d4, V4 sum, const RowPtrD& partial, + const size_t row_c, const size_t col_c) { + if constexpr (kRow < kRowsAC) { + double* HWY_RESTRICT pos = partial.Row(row_c + kRow) + col_c; + if constexpr (hwy::IsSame()) { + hn::Store(sum, d4, pos); + } else { + static_assert(hwy::IsSame()); + const V4 prev = hn::Load(d4, pos); + hn::Store(hn::Add(sum, prev), d4, pos); + } } } +}; // MMAddHorizontalSumsIntoPartial +// Stateless, wraps member functions. +class MMKernel { public: - // Adds the contribution from `Crc` accumulators to the 4x4 tile of C whose - // top left is `tile_c`, after multiplying by `scale`, which is the product of - // the scales of A and B. C is always f32 to ensure sufficient precision. + // Choosing `kMaxMR == kNR` minimizes the ratio of loads to FMA, because + // we load `kNR + kMaxMR` vectors per `kMaxMR * kNR` element tile. + // In general, `M` (batch size) is not a multiple of `kMaxMR`. Thus functions + // that load or store a tile are parameterized on `kRowsAC`: usually `kMaxMR`, + // or less on ISAs with fewer registers, or for the last few rows of A. + static constexpr size_t kMaxMR = 4; + + // Calls `LoopOverKC` for each of `mc` rows of A in steps of `mr`. `A_view` + // is `mc x kc` and `B_view` is `(kNR x kc)`. Both start at row/col 0. 
+ // A2C0 in MOMMS terminology updates a `mc x kNR` slice of the output. + template + static HWY_INLINE void A2C0(const RowPtrBF& A_view, const RowPtrBF& B_view, + size_t mr, const IndexRange& range_mc, + const size_t row_b, size_t kc, Tag tag, + const MMArgs& args) { + HWY_DASSERT(1 <= mr && mr <= kMaxMR); + const size_t row0 = range_mc.begin(); + const size_t mc = range_mc.Num(); + size_t imc = 0; + + // M == 1, or x86 with 8 SIMD registers: + if (HWY_UNLIKELY(mr == 1)) { + for (; imc < mc; ++imc) { + LoopOverKC<1>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args); + } + return; + } + + // AVX2 (16 registers) + if (HWY_UNLIKELY(mr == 2)) { + if (HWY_LIKELY(mc >= 2)) { + for (; imc <= mc - 2; imc += 2) { + LoopOverKC<2>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args); + } + } + if (HWY_UNLIKELY(imc != mc)) { + LoopOverKC<1>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args); + } + return; + } + + HWY_DASSERT(mr == 4); + if (HWY_LIKELY(mc >= 4)) { + for (; imc <= mc - 4; imc += 4) { + LoopOverKC<4>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args); + } + } + const size_t remainder_mc = mc - imc; + HWY_DASSERT(remainder_mc < 4); + if (HWY_UNLIKELY(remainder_mc & 2)) { + LoopOverKC<2>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args); + imc += 2; + } + if (HWY_UNLIKELY(remainder_mc & 1)) { + LoopOverKC<1>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args); + imc += 1; + } + HWY_DASSERT(imc == mc); + } + + private: + // Element-wise multiplies a vector from one row of A with `kNR` vectors, + // each from a row of transposed B, and adds them to `kNR` fp32 `Cc` + // vectors. The lanes of `Cc` are thus a subset of the terms of the dot + // product which is the MatMul result at column `c`. // - // `Crc` are the 16 combinations of an A row vector indexed by `r`, times a - // B column vector indexed by `c`. Their elements are thus a subset of the - // terms of the dot product constituting the final `C[r, c]` result. Thus we - // compute the horizontal sums of each `Crc`. The elements may be permuted - // because we multiply bf16 via `ReorderWidenMulAccumulate`, but this does - // not change their horizontal sum. `buf` is thread-local space for 16 `VF`. - template > - HWY_INLINE void operator()(DF df, float scale, // - VF C00, VF C01, VF C02, VF C03, // - VF C10, VF C11, VF C12, VF C13, // - VF C20, VF C21, VF C22, VF C23, // - VF C30, VF C31, VF C32, VF C33, // - float* HWY_RESTRICT buf, - float* HWY_RESTRICT tile_c, - size_t stride_c) const { - const size_t N = hn::Lanes(df); - // Horizontal reductions (`ReduceSum`) are rather expensive, entailing - // log(N) operations for vectors of length N. Because kRegCols == 4, we can - // instead use `StoreInterleaved4` for a vector length-agnostic 'transpose': - // `buf[0, 4 * N)` holds C00[0], C01[0], C02[0], C03[0], - // C00[1], C01[1], C02[1], C03[1] .. C00[N-1], C01[N-1], C02[N-1], C03[N-1]. - MaybeStoreInterleaved4<0>(df, N, C00, C01, C02, C03, buf); - MaybeStoreInterleaved4<1>(df, N, C10, C11, C12, C13, buf); - MaybeStoreInterleaved4<2>(df, N, C20, C21, C22, C23, buf); - MaybeStoreInterleaved4<3>(df, N, C30, C31, C32, C33, buf); - // Adding N consecutive V4 yields four horizontal sums of Cr0, Cr1, Cr2, Cr3 - // in the elements of one V4. We have four independent rows `r`, hence the - // code is effectively unrolled, which increases throughput. 
- const hn::FixedTag d4; - using V4 = hn::Vec; - V4 sum0 = MaybeLoad<0>(d4, N, buf); - V4 sum1 = MaybeLoad<1>(d4, N, buf); - V4 sum2 = MaybeLoad<2>(d4, N, buf); - V4 sum3 = MaybeLoad<3>(d4, N, buf); + // Why elementwise, when most MatMul instead broadcast one element from A and + // multiply with one element from kr columns in B to obtain kr columns of C? + // We double the compute throughput on NEON_BF16/SVE/AVX3_ZEN4 by using the + // bf16 * bf16 + f32 `ReorderWidenMulAccumulate`. However, this involves + // pairwise adds, whereas the kr-column approach requires that lanes remain + // separate. Our elementwise approach is fine with pairwise adds because they + // do not change the horizontal sum. However, horizontal sums can be costly, + // so we introduce a fast and new(?) vector-length agnostic 'transpose', see + // `MMAddHorizontalSumsIntoPartial`. + template , + class DF = hn::Repartition, class VF = hn::Vec> + static HWY_INLINE void ElementwiseMulAcc(DBF dbf, VBF a, VBF b0, VBF b1, + VBF b2, VBF b3, VF& C0, VF& C1, + VF& C2, VF& C3) { + // This handles a single row of A, so the horizontal sums of `C0..3` are the + // (partial) dot products for 4 consecutive values in one row of C. + static_assert(kNR == 4); + + HWY_DASSERT(HWY_NATIVE_DOT_BF16); - for (size_t i = 1; i < N; ++i) { - sum0 = MaybeAdd<0>(d4, N, sum0, buf + 4 * i); - sum1 = MaybeAdd<1>(d4, N, sum1, buf + 4 * i); - sum2 = MaybeAdd<2>(d4, N, sum2, buf + 4 * i); - sum3 = MaybeAdd<3>(d4, N, sum3, buf + 4 * i); + const DF df; + VF unused_sum1 = hn::Zero(df); + // When implemented natively, this op includes 'free' f32 accumulation. + C0 = hn::ReorderWidenMulAccumulate(df, a, b0, C0, unused_sum1); + C1 = hn::ReorderWidenMulAccumulate(df, a, b1, C1, unused_sum1); + C2 = hn::ReorderWidenMulAccumulate(df, a, b2, C2, unused_sum1); + C3 = hn::ReorderWidenMulAccumulate(df, a, b3, C3, unused_sum1); + // Ensure unused_sum1 was indeed unused. + HWY_DASSERT(hn::AllTrue(df, hn::Eq(unused_sum1, hn::Zero(df)))); + } + + // Like `ElementwiseMulAcc`, but splits BF16 inputs into odd and even f32 + // for use with FMA. Also handles two rows at a time to hide the FMA latency + // (we assume 4 cycles and dual-issue) before writing `C00` again. + template , + class DF = hn::Repartition, class VF = hn::Vec> + static HWY_INLINE void ElementwiseMulAcc2(DBF dbf, VBF a0, VBF a1, VF b0o, + VF b0e, VF b1o, VF b1e, VF b2o, + VF b2e, VF b3o, VF b3e, VF& C00, + VF& C01, VF& C02, VF& C03, VF& C10, + VF& C11, VF& C12, VF& C13) { + const DF df; + HWY_DASSERT(!HWY_NATIVE_DOT_BF16); + // Avoid `ReorderWidenMulAccumulate` because it requires extra adds for + // the two outputs, and `WidenMulPairwiseAdd` because it wastes an + // opportunity for a free f32 add via FMA, and `MulOddAdd` because we want + // to avoid an extra register for a constant. Use scoping to reduce register + // pressure and avoid spills on 32-register targets. Register usage: + // 4 for a0, a1, a0e, a1e; 8 for `b*`, 16 for `C*` = 28. 
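// (Two rows times four columns give 8 independent accumulator chains per
// odd/even half, roughly what 4-cycle FMA latency at two per cycle requires
// to stay busy.)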
+ { + const VF a0e = hn::PromoteEvenTo(df, a0); + C00 = hn::MulAdd(a0e, b0e, C00); + C01 = hn::MulAdd(a0e, b1e, C01); + C02 = hn::MulAdd(a0e, b2e, C02); + C03 = hn::MulAdd(a0e, b3e, C03); + } + { + const VF a1e = hn::PromoteEvenTo(df, a1); + C10 = hn::MulAdd(a1e, b0e, C10); + C11 = hn::MulAdd(a1e, b1e, C11); + C12 = hn::MulAdd(a1e, b2e, C12); + C13 = hn::MulAdd(a1e, b3e, C13); + } + { + const VF a0o = FastPromoteOddTo(df, a0); + C00 = hn::MulAdd(a0o, b0o, C00); + C01 = hn::MulAdd(a0o, b1o, C01); + C02 = hn::MulAdd(a0o, b2o, C02); + C03 = hn::MulAdd(a0o, b3o, C03); + } + { + const VF a1o = FastPromoteOddTo(df, a1); + C10 = hn::MulAdd(a1o, b0o, C10); + C11 = hn::MulAdd(a1o, b1o, C11); + C12 = hn::MulAdd(a1o, b2o, C12); + C13 = hn::MulAdd(a1o, b3o, C13); + } + } + + // Innermost loop over `kc` columns (typically 1024-4096) in steps of one + // vector, for `kRowsAC` rows of `A_view` from range_mc-relative `imc` and + // `B_view` from row 0 (both at column 0). Updates a `kRowsAC x kNR` tile + // with top-left corner `partial.Row(row_ac) + col_c`. Both A and B must be + // BF16 so we can load directly without `Decompress2`, which is expensive for + // NUQ and requires 2x unrolling, which requires more loads. + template + static HWY_INLINE void LoopOverKC(const RowPtrBF& A_view, + const RowPtrBF& B_view, size_t row_ac, + size_t imc, size_t col_c, size_t kc, + Tag tag, const MMArgs& args) { + const hn::ScalableTag dbf; + using VBF = hn::Vec; + const size_t NBF = hn::Lanes(dbf); + + HWY_DASSERT(kRowsAC <= kMaxMR); + HWY_DASSERT(col_c % kNR == 0); + // Rows are aligned to `kMaxMR`, except for the last tile of A. + + // `kRowsAC` rows of A (null for the rest) and `kNR` rows of B. + static_assert(kNR == 4); + const BF16* HWY_RESTRICT ar0 = A_view.Row(imc + 0); + const BF16* HWY_RESTRICT ar1 = kRowsAC > 1 ? A_view.Row(imc + 1) : nullptr; + const BF16* HWY_RESTRICT ar2 = kRowsAC > 2 ? A_view.Row(imc + 2) : nullptr; + const BF16* HWY_RESTRICT ar3 = kRowsAC > 3 ? A_view.Row(imc + 3) : nullptr; + const BF16* HWY_RESTRICT br0 = B_view.Row(0); + const BF16* HWY_RESTRICT br1 = B_view.Row(1); + const BF16* HWY_RESTRICT br2 = B_view.Row(2); + const BF16* HWY_RESTRICT br3 = B_view.Row(3); + + // Ensure `A` and `B` were zero-padded by `DecompressAndZeroPad`. + if constexpr (HWY_IS_DEBUG_BUILD) { + for (size_t i = kc; i < hwy::RoundUpTo(kc, NBF); ++i) { + { + HWY_DASSERT(hwy::ConvertScalarTo(ar0[i]) == 0.0f); + } + if constexpr (kRowsAC > 1) { + HWY_DASSERT(hwy::ConvertScalarTo(ar1[i]) == 0.0f); + } + if constexpr (kRowsAC > 2) { + HWY_DASSERT(hwy::ConvertScalarTo(ar2[i]) == 0.0f); + } + if constexpr (kRowsAC > 3) { + HWY_DASSERT(hwy::ConvertScalarTo(ar3[i]) == 0.0f); + } + HWY_DASSERT(hwy::ConvertScalarTo(br0[i]) == 0.0f); + HWY_DASSERT(hwy::ConvertScalarTo(br1[i]) == 0.0f); + HWY_DASSERT(hwy::ConvertScalarTo(br2[i]) == 0.0f); + HWY_DASSERT(hwy::ConvertScalarTo(br3[i]) == 0.0f); + } + } + + // Accumulate into f32. 
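// (The 16 accumulators form one kMaxMR x kNR tile; rows beyond `kRowsAC`
// remain zero and are skipped by the Maybe* helpers when storing.)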
+ const hn::Repartition df; + using VF = hn::Vec; + VF C00 = hn::Zero(df), C01 = hn::Zero(df), C02 = hn::Zero(df), + C03 = hn::Zero(df), C10 = hn::Zero(df), C11 = hn::Zero(df), + C12 = hn::Zero(df), C13 = hn::Zero(df), C20 = hn::Zero(df), + C21 = hn::Zero(df), C22 = hn::Zero(df), C23 = hn::Zero(df), + C30 = hn::Zero(df), C31 = hn::Zero(df), C32 = hn::Zero(df), + C33 = hn::Zero(df); + + HWY_UNROLL(1) + for (size_t ikc = 0; ikc < kc; ikc += NBF) { + if constexpr (HWY_NATIVE_DOT_BF16) { + const VBF b0 = hn::Load(dbf, br0 + ikc); + const VBF b1 = hn::Load(dbf, br1 + ikc); + const VBF b2 = hn::Load(dbf, br2 + ikc); + const VBF b3 = hn::Load(dbf, br3 + ikc); + { + const VBF a0 = hn::Load(dbf, ar0 + ikc); + ElementwiseMulAcc(dbf, a0, b0, b1, b2, b3, C00, C01, C02, C03); + } + if constexpr (kRowsAC > 1) { + const VBF a1 = hn::Load(dbf, ar1 + ikc); + ElementwiseMulAcc(dbf, a1, b0, b1, b2, b3, C10, C11, C12, C13); + } + if constexpr (kRowsAC > 2) { + const VBF a2 = hn::Load(dbf, ar2 + ikc); + ElementwiseMulAcc(dbf, a2, b0, b1, b2, b3, C20, C21, C22, C23); + } + if constexpr (kRowsAC > 3) { + const VBF a3 = hn::Load(dbf, ar3 + ikc); + ElementwiseMulAcc(dbf, a3, b0, b1, b2, b3, C30, C31, C32, C33); + } + } else { + VF b0e, b1e, b2e, b3e, b0o, b1o, b2o, b3o; + { + const VBF b0 = hn::Load(dbf, br0 + ikc); + const VBF b1 = hn::Load(dbf, br1 + ikc); + const VBF b2 = hn::Load(dbf, br2 + ikc); + const VBF b3 = hn::Load(dbf, br3 + ikc); + b0e = hn::PromoteEvenTo(df, b0); + b1e = hn::PromoteEvenTo(df, b1); + b2e = hn::PromoteEvenTo(df, b2); + b3e = hn::PromoteEvenTo(df, b3); + b0o = FastPromoteOddTo(df, b0); + b1o = FastPromoteOddTo(df, b1); + b2o = FastPromoteOddTo(df, b2); + b3o = FastPromoteOddTo(df, b3); + } + + { + const VBF a0 = hn::Load(dbf, ar0 + ikc); + const VBF a1 = kRowsAC > 1 ? hn::Load(dbf, ar1 + ikc) : a0; + ElementwiseMulAcc2(dbf, a0, a1, b0o, b0e, b1o, b1e, b2o, b2e, b3o, + b3e, C00, C01, C02, C03, C10, C11, C12, C13); + } + if constexpr (kRowsAC > 2) { + const VBF a2 = hn::Load(dbf, ar2 + ikc); + const VBF a3 = kRowsAC > 3 ? hn::Load(dbf, ar3 + ikc) : a2; + ElementwiseMulAcc2(dbf, a2, a3, b0o, b0e, b1o, b1e, b2o, b2e, b3o, + b3e, C20, C21, C22, C23, C30, C31, C32, C33); + } + } + } + + // This is a substantial fraction (about 1/3) of the total time, but is + // called frequently, so do not add a profiler zone. + + if constexpr (hwy::IsSame()) { + if (args.add) { + MMStoreHorizontalSumsIntoC()( + df, C00, C01, C02, C03, C10, C11, C12, C13, C20, C21, C22, C23, C30, + C31, C32, C33, row_ac, col_c, args); + } else { + MMStoreHorizontalSumsIntoC()( + df, C00, C01, C02, C03, C10, C11, C12, C13, C20, C21, C22, C23, C30, + C31, C32, C33, row_ac, col_c, args); + } + } else { + MMAddHorizontalSumsIntoPartial()( + df, C00, C01, C02, C03, C10, C11, C12, C13, C20, C21, C22, C23, C30, + C31, C32, C33, row_ac, col_c, args.partial); } - // Scale, then store to four elements per row of `tile_c`. - const V4 vscale = hn::Set(d4, scale); - MaybeMulAdd<0>(d4, sum0, vscale, tile_c, stride_c); - MaybeMulAdd<1>(d4, sum1, vscale, tile_c, stride_c); - MaybeMulAdd<2>(d4, sum2, vscale, tile_c, stride_c); - MaybeMulAdd<3>(d4, sum3, vscale, tile_c, stride_c); } }; -// Streams a `kNumRows` high strip of `A` and the transposed `B`, then writes a -// *finished* tile of f32 `C` whose top left is (row_ac, row_b_col_c). -// TODO: loop over sections instead of full rows and accumulate into `tile_c`. -// `buf` is 16 vectors of thread-local storage. 
-template -HWY_INLINE void MatMulTile(const ConstMat& A, const size_t row_ac, - const ConstMat& B, const size_t row_b_col_c, - const float scale, const float* HWY_RESTRICT add, - float* HWY_RESTRICT buf, const RowPtr& C) { - // Decompress A and B to which type, which will then be widened to f32, - // multiplied, added once into f32, then promoted to f64 and accumulated. - // NEON_BF16/SVE/AVX3_ZEN4 have instructions for bf16 * bf16 + f32 which are - // more efficient than f32 * f32 + f32 because they process twice as many - // lanes at a time. If available, we definitely want to use them. Otherwise, - // bf16 is still worthwhile if A (activations) are bf16: SFP weights are - // cheaper to decode to bf16, relative to the minor extra cost of promoting - // bf16 when multiplying. However, if A is f32, demoting to bf16 can be - // expensive unless we also have native bf16 dot. - using Raw = hwy::If(), BF16, float>; - const hn::ScalableTag dr; - using VR = hn::Vec; - const size_t NR = hn::Lanes(dr); - - const Range1D cols_ab(0, A.Extents().cols); - HWY_DASSERT(row_ac + kNumRows <= A.Extents().rows); - HWY_DASSERT(row_b_col_c + kNumRows <= B.Extents().rows); - HWY_DASSERT(cols_ab.end() % (2 * NR) == 0); - - static_assert(kRegRows == 4); - const BRow<0, MatTB> b_row0(B, row_b_col_c); - const BRow<1, MatTB> b_row1(B, row_b_col_c); - const BRow<2, MatTB> b_row2(B, row_b_col_c); - const BRow<3, MatTB> b_row3(B, row_b_col_c); - - const ALoadAccumulate<0, MatTA> a_row0(A, row_ac); - const ALoadAccumulate<1, MatTA> a_row1(A, row_ac); - const ALoadAccumulate<2, MatTA> a_row2(A, row_ac); - const ALoadAccumulate<3, MatTA> a_row3(A, row_ac); - - const hn::Repartition df; - using VF = hn::Vec; - VF C00, C01, C02, C03; - VF C10, C11, C12, C13; - VF C20, C21, C22, C23; - VF C30, C31, C32, C33; - - size_t col_ab = cols_ab.begin(); - { // First iteration initializes the `Crc` vectors. - VR b00, b01, b10, b11, b20, b21, b30, b31; - b_row0.Load2(dr, col_ab, b00, b01); - b_row1.Load2(dr, col_ab, b10, b11); - b_row2.Load2(dr, col_ab, b20, b21); - b_row3.Load2(dr, col_ab, b30, b31); - - a_row0.template First(dr, b00, b01, b10, b11, b20, b21, b30, b31, - C00, C01, C02, C03); - a_row1.template First(dr, b00, b01, b10, b11, b20, b21, b30, b31, - C10, C11, C12, C13); - a_row2.template First(dr, b00, b01, b10, b11, b20, b21, b30, b31, - C20, C21, C22, C23); - a_row3.template First(dr, b00, b01, b10, b11, b20, b21, b30, b31, - C30, C31, C32, C33); - col_ab += 2 * NR; - } - - // `2 * NR` per iteration because `Load2` returns two vectors. - HWY_UNROLL(1) - for (; col_ab < cols_ab.end(); col_ab += 2 * NR) { - VR b00, b01, b10, b11, b20, b21, b30, b31; - b_row0.Load2(dr, col_ab, b00, b01); - b_row1.Load2(dr, col_ab, b10, b11); - b_row2.Load2(dr, col_ab, b20, b21); - b_row3.Load2(dr, col_ab, b30, b31); - - a_row0.template Next(dr, col_ab, b00, b01, b10, b11, b20, b21, - b30, b31, C00, C01, C02, C03); - a_row1.template Next(dr, col_ab, b00, b01, b10, b11, b20, b21, - b30, b31, C10, C11, C12, C13); - a_row2.template Next(dr, col_ab, b00, b01, b10, b11, b20, b21, - b30, b31, C20, C21, C22, C23); - a_row3.template Next(dr, col_ab, b00, b01, b10, b11, b20, b21, - b30, b31, C30, C31, C32, C33); - } - - // TODO: hoist into outer loop. - float* HWY_RESTRICT C_tile = C.Row(row_ac) + row_b_col_c; - InitC(add, row_b_col_c, C_tile, C.Stride()); - - AddHorizontalSums()(df, scale, C00, C01, C02, C03, C10, C11, C12, - C13, C20, C21, C22, C23, C30, C31, C32, C33, - buf, C_tile, C.Stride()); -} +// Stateless, wraps member functions. 
+class MMScaleDemoteAdd { + public: + // Fills the `range_mc/range_nc` region of `outputs.C` by multiplying the + // same region of `outputs.partial` by `outputs.scale`, which is the product + // of the scales of A and B, demoting from f64 to f32, then if `outputs.add` + // is nonzero, adding it to each row. + // TODO: fuse with subsequent operations - function pointer? + // Although this region in `outputs.C` is not touched again, streaming stores + // do not help on SKX and Zen4. TODO: re-check this. + static HWY_INLINE void FillC(const IndexRange& range_mc, + const IndexRange& range_nc, const MMArgs& args) { + size_t row_c = range_mc.begin(); + if (args.add) { + constexpr bool kAdd = true; + if (range_mc.Num() >= 4) { + for (; row_c <= range_mc.end() - 4; row_c += 4) { + Do4Rows(row_c, range_nc, args); + } + } + for (; row_c < range_mc.end(); ++row_c) { + Do1Row(row_c, range_nc, args); + } + } else { + constexpr bool kAdd = false; + if (range_mc.Num() >= 4) { + for (; row_c <= range_mc.end() - 4; row_c += 4) { + Do4Rows(row_c, range_nc, args); + } + } + for (; row_c < range_mc.end(); ++row_c) { + Do1Row(row_c, range_nc, args); + } + } + } -template -HWY_NOINLINE void MatMulImpl(const ConstMat& A, const ConstMat& B, - const float* HWY_RESTRICT add, MatMulEnv& env, - const RowPtr& C) { - // PROFILER_ZONE("Matmul"); - HWY_DASSERT(A.Extents().cols == B.Extents().cols); - const size_t batch_size = A.Extents().rows; - HWY_DASSERT(C.Cols() % kRegCols == 0); - HWY_DASSERT(C.Stride() >= C.Cols()); - HWY_DASSERT(B.Extents().rows == C.Cols()); - - const float scale = A.scale * B.scale; - - // We currently write C directly, which touches more memory than fits in L3. - // TODO: add another level of loops to finish L3-sized pieces of C at a time. - const size_t tilesY = hwy::DivCeil(batch_size, kRegRows); - const size_t tilesX = C.Cols() / kRegCols; - - env.Pool().Run( - 0, tilesX * tilesY, [&](const uint64_t idx_tile, size_t thread) HWY_ATTR { - // TODO: when using PerClusterPool, compute lp from outer and inner. - float* HWY_RESTRICT buf = env.Buf().Batch(thread); - const size_t tx = idx_tile % tilesX; - const size_t ty = idx_tile / tilesX; - const size_t row_ac = ty * kRegRows; - const size_t row_b_col_c = tx * kRegCols; - // How many rows of C are left to compute. If more than 4, this - // tile still only computes 4 rows. - const size_t num_rows = batch_size - row_ac; - HWY_DASSERT(num_rows != 0); - switch (num_rows) { - case 1: - MatMulTile<1, kAdd>(A, row_ac, B, row_b_col_c, scale, add, buf, C); - break; - case 2: - MatMulTile<2, kAdd>(A, row_ac, B, row_b_col_c, scale, add, buf, C); - break; - case 3: - MatMulTile<3, kAdd>(A, row_ac, B, row_b_col_c, scale, add, buf, C); - break; - default: - MatMulTile<4, kAdd>(A, row_ac, B, row_b_col_c, scale, add, buf, C); + private: + // Unrolled for 4 rows to reduce the number of loads from `add`. 
+ template + static HWY_INLINE void Do4Rows(size_t row_c, const IndexRange& range_nc, + const MMArgs& args) { + const hn::ScalableTag dd; + const hn::Rebind df; // result of DemoteTo + using VD = hn::Vec; + const size_t ND = hn::Lanes(dd); + const VD vscale = hn::Set(dd, args.scale); + + const double* HWY_RESTRICT pr0 = args.partial.Row(row_c + 0); + const double* HWY_RESTRICT pr1 = args.partial.Row(row_c + 1); + const double* HWY_RESTRICT pr2 = args.partial.Row(row_c + 2); + const double* HWY_RESTRICT pr3 = args.partial.Row(row_c + 3); + + float* HWY_RESTRICT cr0 = args.C.Row(row_c + 0); + float* HWY_RESTRICT cr1 = args.C.Row(row_c + 1); + float* HWY_RESTRICT cr2 = args.C.Row(row_c + 2); + float* HWY_RESTRICT cr3 = args.C.Row(row_c + 3); + + // We manually unroll 2x for higher IPC in batch=1. + size_t col_c = range_nc.begin(); + if (HWY_LIKELY(range_nc.Num() >= 2 * ND)) { + HWY_UNROLL(1) + for (; col_c <= range_nc.end() - 2 * ND; col_c += 2 * ND) { + VD a0, a1; // unused if !kAdd + if constexpr (kAdd) { + // Promoting to double lets us fuse the Add into MulAdd. + a0 = hn::PromoteTo(dd, hn::Load(df, args.add + col_c)); + a1 = hn::PromoteTo(dd, hn::Load(df, args.add + col_c + ND)); } - }); -} + const VD d00 = hn::Load(dd, pr0 + col_c); + const VD d01 = hn::Load(dd, pr0 + col_c + ND); + const VD d10 = hn::Load(dd, pr1 + col_c); + const VD d11 = hn::Load(dd, pr1 + col_c + ND); + const VD d20 = hn::Load(dd, pr2 + col_c); + const VD d21 = hn::Load(dd, pr2 + col_c + ND); + const VD d30 = hn::Load(dd, pr3 + col_c); + const VD d31 = hn::Load(dd, pr3 + col_c + ND); + VD m00, m01, m10, m11, m20, m21, m30, m31; + if constexpr (kAdd) { + m00 = hn::MulAdd(d00, vscale, a0); + m01 = hn::MulAdd(d01, vscale, a1); + m10 = hn::MulAdd(d10, vscale, a0); + m11 = hn::MulAdd(d11, vscale, a1); + m20 = hn::MulAdd(d20, vscale, a0); + m21 = hn::MulAdd(d21, vscale, a1); + m30 = hn::MulAdd(d30, vscale, a0); + m31 = hn::MulAdd(d31, vscale, a1); + } else { + m00 = hn::Mul(d00, vscale); + m01 = hn::Mul(d01, vscale); + m10 = hn::Mul(d10, vscale); + m11 = hn::Mul(d11, vscale); + m20 = hn::Mul(d20, vscale); + m21 = hn::Mul(d21, vscale); + m30 = hn::Mul(d30, vscale); + m31 = hn::Mul(d31, vscale); + } + // Note that Stream is neutral on SKX and harmful on Zen4. + hn::Store(hn::DemoteTo(df, m00), df, cr0 + col_c); + hn::Store(hn::DemoteTo(df, m01), df, cr0 + col_c + ND); + hn::Store(hn::DemoteTo(df, m10), df, cr1 + col_c); + hn::Store(hn::DemoteTo(df, m11), df, cr1 + col_c + ND); + hn::Store(hn::DemoteTo(df, m20), df, cr2 + col_c); + hn::Store(hn::DemoteTo(df, m21), df, cr2 + col_c + ND); + hn::Store(hn::DemoteTo(df, m30), df, cr3 + col_c); + hn::Store(hn::DemoteTo(df, m31), df, cr3 + col_c + ND); + } + } + + for (; col_c < range_nc.end(); col_c += ND) { + const size_t remaining = range_nc.end() - col_c; + HWY_DASSERT(remaining < 2 * ND); + + VD a0; // unused if !kAdd + if constexpr (kAdd) { + // Promoting to double lets us fuse the Add into MulAdd. 
+ a0 = hn::PromoteTo(dd, hn::LoadN(df, args.add + col_c, remaining)); + } + const VD d00 = hn::LoadN(dd, pr0 + col_c, remaining); + const VD d10 = hn::LoadN(dd, pr1 + col_c, remaining); + const VD d20 = hn::LoadN(dd, pr2 + col_c, remaining); + const VD d30 = hn::LoadN(dd, pr3 + col_c, remaining); + VD m00, m10, m20, m30; + if constexpr (kAdd) { + m00 = hn::MulAdd(d00, vscale, a0); + m10 = hn::MulAdd(d10, vscale, a0); + m20 = hn::MulAdd(d20, vscale, a0); + m30 = hn::MulAdd(d30, vscale, a0); + } else { + m00 = hn::Mul(d00, vscale); + m10 = hn::Mul(d10, vscale); + m20 = hn::Mul(d20, vscale); + m30 = hn::Mul(d30, vscale); + } + hn::StoreN(hn::DemoteTo(df, m00), df, cr0 + col_c, remaining); + hn::StoreN(hn::DemoteTo(df, m10), df, cr1 + col_c, remaining); + hn::StoreN(hn::DemoteTo(df, m20), df, cr2 + col_c, remaining); + hn::StoreN(hn::DemoteTo(df, m30), df, cr3 + col_c, remaining); + } + } + + // Same as above but handles a single row (for remainder rows). + template + static HWY_INLINE void Do1Row(size_t row_c, const IndexRange& range_nc, + const MMArgs& args) { + const hn::ScalableTag dd; + const hn::Rebind df; // result of DemoteTo + using VD = hn::Vec; + const size_t ND = hn::Lanes(dd); + const VD vscale = hn::Set(dd, args.scale); + + const double* HWY_RESTRICT pr0 = args.partial.Row(row_c + 0); + float* HWY_RESTRICT cr0 = args.C.Row(row_c + 0); + + // We manually unroll 2x for higher IPC in batch=1. + size_t col_c = range_nc.begin(); + if (HWY_LIKELY(range_nc.Num() >= 2 * ND)) { + HWY_UNROLL(1) + for (; col_c <= range_nc.end() - 2 * ND; col_c += 2 * ND) { + VD a0, a1; // unused if !kAdd + if constexpr (kAdd) { + // Promoting to double lets us fuse the Add into MulAdd. + a0 = hn::PromoteTo(dd, hn::Load(df, args.add + col_c)); + a1 = hn::PromoteTo(dd, hn::Load(df, args.add + col_c + ND)); + } + const VD d00 = hn::Load(dd, pr0 + col_c); + const VD d01 = hn::Load(dd, pr0 + col_c + ND); + VD m00, m01; + if constexpr (kAdd) { + m00 = hn::MulAdd(d00, vscale, a0); + m01 = hn::MulAdd(d01, vscale, a1); + } else { + m00 = hn::Mul(d00, vscale); + m01 = hn::Mul(d01, vscale); + } + // Note that Stream is neutral on SKX and harmful on Zen4. + hn::Store(hn::DemoteTo(df, m00), df, cr0 + col_c); + hn::Store(hn::DemoteTo(df, m01), df, cr0 + col_c + ND); + } + } + + for (; col_c < range_nc.end(); col_c += ND) { + const size_t remaining = range_nc.end() - col_c; + HWY_DASSERT(remaining < 2 * ND); + + VD a0; // unused if !kAdd + if constexpr (kAdd) { + // Promoting to double lets us fuse the Add into MulAdd. + a0 = hn::PromoteTo(dd, hn::LoadN(df, args.add + col_c, remaining)); + } + const VD d00 = hn::LoadN(dd, pr0 + col_c, remaining); + VD m00; + if constexpr (kAdd) { + m00 = hn::MulAdd(d00, vscale, a0); + } else { + m00 = hn::Mul(d00, vscale); + } + hn::StoreN(hn::DemoteTo(df, m00), df, cr0 + col_c, remaining); + } + } +}; // MMScaleDemoteAdd + +// Called on the main thread with the entire N range, or by each package with +// a static partition of N. This class contains several variants of the +// outer M/N/K loops, and calls `A2C0` which loops over the inner KC and MC. +// Its member variables avoid long argument lists in Do*(). 
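For reference, the per-element effect of `MMScaleDemoteAdd::FillC` above can be written as a plain scalar loop. This is an illustrative sketch with hypothetical raw-pointer arguments and strides, not the SIMD code path:

// Scalar reference: C[r][c] = float(partial[r][c] * scale [+ add[c]]).
// `partial_stride`/`c_stride` are hypothetical row strides; `add` may be null.
static void FillCScalarSketch(const double* partial, size_t partial_stride,
                              float* c, size_t c_stride, double scale,
                              const float* add, size_t row_begin, size_t row_end,
                              size_t col_begin, size_t col_end) {
  for (size_t r = row_begin; r < row_end; ++r) {
    for (size_t col = col_begin; col < col_end; ++col) {
      double v = partial[r * partial_stride + col] * scale;
      if (add != nullptr) v += add[col];  // add is promoted to f64 before the FMA
      c[r * c_stride + col] = static_cast<float>(v);
    }
  }
}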
+class MMPerPackage { + public: + template + MMPerPackage(const ConstMat& A, const MMArgs& args, + const MMConfig& config, size_t pkg_idx, + const IndexRange& range_np) + : args_(args), + pkg_idx_(pkg_idx), + range_np_(range_np), + mr_(config.MR()), + ranges_mc_(config.RangesOfMC(A.Extents().rows)), + ranges_kc_(config.RangesOfKC(A.Extents().cols)), + ranges_nc_(config.RangesOfNC(range_np)), + order_(config.Order()), + inner_tasks_(config.InnerTasks()), + out_(config.Out()) { + // May be overwritten with a view of A, if already BF16. + A_ = args_.env->storage.A(pkg_idx, A.Extents()); + { + MMZone zone; + zone.MaybeEnter("MM.DecompressA", args_); + A_ = DecompressA(A); + } + } + + // B is decompressed several call layers lower, but not all member functions + // depend on TB, so pass it as an argument instead of templating the class. + template + HWY_NOINLINE void operator()(const ConstMat& B) const { + // TODO: include NUQ tables? NumPacked in ConstMat? + const size_t num_packed_B = B.ofs + B.Stride() * B.Extents().rows; + + switch (order_) { + case MMOrder::kNT: + return DoNT(B, num_packed_B); + case MMOrder::kNT_K: + return DoNT_K(B, num_packed_B); + case MMOrder::kNT_MT: + return DoNT_MT(B, num_packed_B); + case MMOrder::kNT_MT_K: + return DoNT_MT_K(B, num_packed_B); + default: + HWY_UNREACHABLE; + } + } + + private: + // Compute size of per-worker storage for `kNR` row ranges of B. Stack + // allocation avoids passing a worker index. + static constexpr size_t B_stride_max_ = + StrideForCyclicOffsets(MMStorage::kMaxKC); + static constexpr size_t B_storage_max_ = + kNR * B_stride_max_ + Allocator::MaxQuantumBytes() / sizeof(BF16); + + // Granularity of `ForNP`. B rows produce C columns, so we + // want a multiple of the line size to prevent false sharing. + static size_t MultipleNP() { + return HWY_MAX(kNR, Allocator::LineBytes() / sizeof(float)); + } + + // Single M and K, parallel N. Fills all of C directly. + template + HWY_INLINE void DoNT(const ConstMat& B, size_t num_packed_B) const { + MMZone zone; + zone.MaybeEnter("MM.NT", args_); + HWY_DASSERT(ranges_mc_.NumTasks() == 1); + HWY_DASSERT(ranges_kc_.NumTasks() == 1); + const IndexRange& range_M = ranges_mc_.Range(0); + const IndexRange& range_K = ranges_kc_.Range(0); + const size_t K = range_K.Num(); + const RowPtrBF& A_view = A_.View(range_M.begin(), 0, K); + const size_t B_stride = StrideForCyclicOffsets(K); + + // Similar to `loop_nc` below, but here we hoisted `A_view`. + args_.env->parallel.ForNP( + range_np_, MultipleNP(), inner_tasks_, pkg_idx_, + [&](const IndexRange& range_nc) HWY_ATTR { + HWY_ALIGN BF16 B_storage[B_storage_max_]; // TLS + const RowPtrBF B_view(B_storage, K, B_stride); + + for (size_t row_b = range_nc.begin(); row_b < range_nc.end(); + row_b += kNR) { + { + MMZone zone; + zone.MaybeEnter("MM.NT.DecB", args_); + DecompressB(B, num_packed_B, row_b, range_K, B_view); + } + MMKernel::A2C0(A_view, B_view, mr_, range_M, row_b, K, MMSetC(), + args_); + } + }); + + HWY_DASSERT(out_ == MMOut::kDirect); // already filled C + } + + // Single M, parallel N, sequential K. Fills all of partial. + template + HWY_INLINE void DoNT_K(const ConstMat& B, size_t num_packed_B) const { + MMZone zone; + zone.MaybeEnter("MM.NT_K", args_); + HWY_DASSERT(ranges_mc_.NumTasks() == 1); + const IndexRange& range_mc = ranges_mc_.Range(0); + + // Loop over NC/MC/KC, called from the outer loops over K/N. + // C++14 generic lambda enables hoisting branches via template + // argument, while also capturing to avoid long argument lists. 
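The tag-dispatch idiom mentioned in the comment above can be shown in isolation. A minimal sketch with hypothetical `SetTag`/`AddTag` types standing in for `MMSetPartial`/`MMAddPartial`:

// The overload is resolved at compile time inside the generic lambda, so the
// set-vs-add decision is hoisted out of the hot loop instead of branching per
// element.
struct SetTag {};
struct AddTag {};
inline void Accumulate(double* out, double value, SetTag) { *out = value; }
inline void Accumulate(double* out, double value, AddTag) { *out += value; }

template <class Body>
void RunFirstThenRest(const Body& body) {
  body(SetTag());  // first KC block: overwrite, so no zero-init of `partial`
  body(AddTag());  // remaining KC blocks: accumulate
}
// Usage: double p = 0.0;
//        RunFirstThenRest([&](auto tag) { Accumulate(&p, 1.5, tag); });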
+ const auto loop_nc = [&](BF16* B_storage, const IndexRange& range_kc, + const IndexRange& range_nc, + auto out_tag) HWY_ATTR { + const size_t kc = range_kc.Num(); + const RowPtrBF& A_view = A_.View(range_mc.begin(), range_kc.begin(), kc); + const RowPtrBF B_view(B_storage, kc, StrideForCyclicOffsets(kc)); + + for (size_t row_b = range_nc.begin(); row_b < range_nc.end(); + row_b += kNR) { + { + MMZone zone; + zone.MaybeEnter("MM.NT_K.DecB", args_); + DecompressB(B, num_packed_B, row_b, range_kc, B_view); + } + MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, kc, out_tag, + args_); + } + }; + + args_.env->parallel.ForNP( + range_np_, MultipleNP(), inner_tasks_, pkg_idx_, + [&](const IndexRange& range_nc) HWY_ATTR { + HWY_ALIGN BF16 B_storage[B_storage_max_]; // TLS + + // Peel off the first iteration of the kc loop: avoid + // zero-initializing `partial` by writing into it. + ranges_kc_.VisitFirst([&](const IndexRange& range_kc) { + loop_nc(B_storage, range_kc, range_nc, MMSetPartial()); + }); + ranges_kc_.VisitRemaining([&](const IndexRange& range_kc) { + loop_nc(B_storage, range_kc, range_nc, MMAddPartial()); + }); + }); + + MMZone fill_zone; + if (out_ == MMOut::kCopy) { + fill_zone.MaybeEnter("MM.NT_K.FillC", args_); + MMScaleDemoteAdd::FillC(range_mc, range_np_, args_); + } else if (out_ == MMOut::kParM) { + fill_zone.MaybeEnter("MM.NT_K.FillC.ParM", args_); + args_.env->parallel.ForRangeMC( + range_mc, pkg_idx_, [&](size_t row_a) HWY_ATTR { + MMScaleDemoteAdd::FillC(IndexRange(row_a, row_a + 1), range_np_, + args_); + }); + } else { + HWY_UNREACHABLE; // kDirect is only used with kNT. + } + } + + // Parallel loops over mc/nc blocks of M/range_np, single K. + // Fills `mc x nc` sections of C directly, in parallel. + template + HWY_INLINE void DoNT_MT(const ConstMat& B, size_t num_packed_B) const { + MMZone zone; + zone.MaybeEnter("MM.NT_MT", args_); + HWY_DASSERT(ranges_kc_.NumTasks() == 1); + const IndexRange& range_K = ranges_kc_.Range(0); + const size_t K = range_K.Num(); + + const size_t B_stride = StrideForCyclicOffsets(K); + + // Sequential loop over NC/MC/KC, similar to `loop_nc` below + // except for the profiler strings and `out_tag`. + args_.env->parallel.ForRangesMC_NC( + ranges_mc_, ranges_nc_, pkg_idx_, + [&](const IndexRange& range_mc, const IndexRange& range_nc) HWY_ATTR { + const RowPtrBF& A_view = A_.View(range_mc.begin(), 0, K); + HWY_ALIGN BF16 B_storage[B_storage_max_]; // TLS + const RowPtrBF B_view(B_storage, K, B_stride); + + for (size_t row_b = range_nc.begin(); row_b < range_nc.end(); + row_b += kNR) { + { + MMZone zone; + zone.MaybeEnter("MM.NT_MT.DecB", args_); + DecompressB(B, num_packed_B, row_b, range_K, B_view); + } + MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, K, MMSetC(), + args_); + } + }); + + HWY_DASSERT(out_ == MMOut::kDirect); // already filled C + } + + // Parallel loops over mc/nc blocks of M/range_np, sequential K. + // Fills `mc x nc` sections of `partial`, then `C`, in parallel. + template + HWY_INLINE void DoNT_MT_K(const ConstMat& B, size_t num_packed_B) const { + MMZone zone; + zone.MaybeEnter("MM.NT_MT_K", args_); + const size_t kc_max = ranges_kc_.TaskSize(); + HWY_DASSERT(kc_max <= MMStorage::kMaxKC); + const size_t B_stride = StrideForCyclicOffsets(kc_max); + // Sequential loop over NC/MC/KC, for when the M/N loops are + // already parallel. This is B3A2C0 in MOMMS terminology: we read + // `mc x kc` of A, `nc x kc` of B, update `mc x nc` of `partial`. 
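A plain-C++ sketch of this B3A2C0 blocking, with dense float inputs in place of the compressed/BF16 types and `RowPtr` views used here; the mc/nc block loops are the ones MatMul parallelizes, and the kc loop stays sequential so only the first block overwrites `partial`:

static void BlockedMatMulSketch(const float* A, const float* Bt, double* partial,
                                size_t M, size_t N, size_t K,
                                size_t mc, size_t nc, size_t kc) {
  // Bt is "B transposed": Bt[n * K + k] holds element (k, n) of the original B.
  for (size_t m0 = 0; m0 < M; m0 += mc) {      // parallel (ForRangesMC_NC)
    for (size_t n0 = 0; n0 < N; n0 += nc) {    // parallel (ForRangesMC_NC)
      const size_t m1 = m0 + mc < M ? m0 + mc : M;
      const size_t n1 = n0 + nc < N ? n0 + nc : N;
      for (size_t k0 = 0; k0 < K; k0 += kc) {  // sequential over K sections
        const size_t k1 = k0 + kc < K ? k0 + kc : K;
        for (size_t m = m0; m < m1; ++m) {
          for (size_t n = n0; n < n1; ++n) {
            double sum = (k0 == 0) ? 0.0 : partial[m * N + n];  // peeled first kc
            for (size_t k = k0; k < k1; ++k) {
              sum += static_cast<double>(A[m * K + k]) * Bt[n * K + k];
            }
            partial[m * N + n] = sum;
          }
        }
      }
    }
  }
}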
+ const auto loop_nc = [&](const RowPtrBF& B_view, const IndexRange& range_mc, + const IndexRange& range_kc, + const IndexRange& range_nc, + auto out_tag) HWY_ATTR { + const size_t kc = range_kc.Num(); + const RowPtrBF& A_view = A_.View(range_mc.begin(), range_kc.begin(), kc); + + for (size_t row_b = range_nc.begin(); row_b < range_nc.end(); + row_b += kNR) { + { + MMZone zone; + zone.MaybeEnter("MM.NT_MT_K.DecB", args_); + DecompressB(B, num_packed_B, row_b, range_kc, B_view); + } + MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, kc, out_tag, + args_); + } + }; // loop_nc + args_.env->parallel.ForRangesMC_NC( + ranges_mc_, ranges_nc_, pkg_idx_, + [&](const IndexRange& range_mc, const IndexRange& range_nc) HWY_ATTR { + HWY_ALIGN BF16 B_storage[B_storage_max_]; // TLS + const RowPtrBF B_view(B_storage, kc_max, B_stride); + + // Peel off the first iteration of the kc loop: avoid + // zero-initializing `partial` by writing into it. + ranges_kc_.VisitFirst([&](const IndexRange& range_kc) { + loop_nc(B_view, range_mc, range_kc, range_nc, MMSetPartial()); + }); + ranges_kc_.VisitRemaining([&](const IndexRange& range_kc) { + loop_nc(B_view, range_mc, range_kc, range_nc, MMAddPartial()); + }); + + // Already in parallel section, hence no `kParM`, and + // `kDirect` is only used with `kNT_MT`. + HWY_DASSERT(out_ == MMOut::kCopy); + MMZone fill_zone; + fill_zone.MaybeEnter("MM.NT_MT_K.FillC", args_); + MMScaleDemoteAdd::FillC(range_mc, range_nc, args_); + }); + } + + // Decompresses all `M x K` from `A` into `pkg_A`. Assumes `TA` is a seekable + // type (i.e., not SFP/NUQ) so we can use pointer arithmetic. + template + HWY_NOINLINE void DoDecompressA(const ConstMat& A, MMParA par_a) const { + const IndexRange all_M(0, A.extents.rows); + const IndexRange all_K(0, A.extents.cols); + HWY_DASSERT(all_K.Num() == A_.Cols()); + + const hn::ScalableTag dbf; + const size_t NBF = hn::Lanes(dbf); + static_assert(hwy::IsSameEither(), "Can seek"); + + const auto do_range = [&](const IndexRange& range_M, + const IndexRange& range_K) HWY_ATTR { + const size_t col0 = range_K.begin(); + const size_t cols = range_K.Num(); + for (size_t row_a : range_M) { + const PackedSpan from = + MakeSpan(A.ptr + A.Row(row_a) + col0, cols); + BF16* HWY_RESTRICT to = A_.Row(row_a) + col0; + DecompressAndZeroPad(dbf, from, 0, to, cols); + // Verify that we zero-padded. + if constexpr (HWY_IS_DEBUG_BUILD) { + for (size_t i = cols; i < hwy::RoundUpTo(cols, NBF); ++i) { + HWY_DASSERT(hwy::ConvertScalarTo(to[i]) == 0.0f); + } + } + } + }; + + switch (par_a) { + case MMParA::kNone: + do_range(all_M, all_K); + break; + case MMParA::kK1: + case MMParA::kK2: + case MMParA::kK4: { + const size_t inner_tasks = static_cast(par_a); + // At least one vector, otherwise DecompressAndZeroPad will add + // padding, which might overwrite neighboring tasks. Also a whole cache + // line to avoid false sharing. + const size_t multiple_K = + HWY_MAX(NBF, Allocator::LineBytes() / sizeof(BF16)); + + args_.env->parallel.ForNP( + all_K, multiple_K, inner_tasks, pkg_idx_, + [&](const IndexRange& range_K) { do_range(all_M, range_K); }); + break; + } + case MMParA::kM: + args_.env->parallel.ForRangeMC(all_M, pkg_idx_, [&](size_t row_a) { + do_range(IndexRange(row_a, row_a + 1), all_K); + }); + break; + } + } + + // Autotuning wrapper for `DoDecompressA`. 
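The `DecompressAndZeroPad` contract relied on above (and checked by the debug loops) amounts to: convert `cols` values and zero the tail up to the next lane multiple. A scalar sketch for float input, assuming `BF16` is an alias for `hwy::bfloat16_t` and using the Highway scalar helpers that appear elsewhere in this file:

// Zero-padding lets the SIMD kernels load full vectors past `cols` without
// affecting the dot products.
static void DecompressRowSketch(const float* from, hwy::bfloat16_t* to,
                                size_t cols, size_t lanes) {
  for (size_t i = 0; i < cols; ++i) {
    to[i] = hwy::ConvertScalarTo<hwy::bfloat16_t>(from[i]);
  }
  for (size_t i = cols; i < hwy::RoundUpTo(cols, lanes); ++i) {
    to[i] = hwy::ConvertScalarTo<hwy::bfloat16_t>(0.0f);
  }
}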
+ template + HWY_INLINE RowPtrBF DecompressA(const ConstMat& A) const { + MMAutoTune& autotune = args_.per_key->autotune_par_a[pkg_idx_]; + // If already BF16, maybe return a view: + if constexpr (hwy::IsSame()) { + // Only if no zero-padding required. + const size_t NBF = hn::Lanes(hn::ScalableTag()); + if (HWY_LIKELY(A.extents.cols % NBF == 0)) { + const BF16* pos = A.ptr + A.Row(0); + return RowPtrBF(const_cast(pos), A.extents.cols, A.Stride()); + } + } + + if (HWY_LIKELY(autotune.Best())) { + DoDecompressA(A, *autotune.Best()); + return A_; + } + + // First call: generate candidates. + if (HWY_UNLIKELY(!autotune.HasCandidates())) { + std::vector candidates = {MMParA::kK1, MMParA::kK2, MMParA::kK4}; + if (A.extents.rows == 1) { + candidates.push_back(MMParA::kNone); + } else { + candidates.push_back(MMParA::kM); + } + autotune.SetCandidates(candidates); + } + + const MMParA& par_a = autotune.NextConfig(); + const uint64_t t0 = hwy::timer::Start(); + DoDecompressA(A, par_a); + const uint64_t t1 = + args_.env->have_timer_stop ? hwy::timer::Stop() : hwy::timer::Start(); + const uint64_t min_elapsed = autotune.NotifyTicks(t1 - t0); + if (HWY_UNLIKELY(args_.env->print_measurement && autotune.ShouldPrint())) { + fprintf(stderr, "%s,%7.3f\n", StringFromParA(par_a), + static_cast(min_elapsed) / + hwy::platform::InvariantTicksPerSecond() * 1E6); + } + return A_; + } + + // Decompresses `kNR x kc` from `B[row_b, range_kc.begin()]` to row 0, + // col 0 of `B_view`. Decompressing SFP is relatively cheap on `AVX3_DL` + // thanks to its large table lookups, and less so on other targets. + template + HWY_INLINE void DecompressB(const ConstMat& B, size_t num_packed_B, + const size_t row_b, const IndexRange& range_kc, + const RowPtrBF& B_view) const { + const hn::ScalableTag dbf; + + const PackedSpan B_span = MakeSpan(B.ptr, num_packed_B); + + const size_t kc = range_kc.Num(); + const size_t col0 = range_kc.begin(); + + for (size_t r = 0; r < kNR; ++r) { + const size_t packed_ofs = B.Row(row_b + r) + col0; + BF16* HWY_RESTRICT to = B_view.Row(r); + DecompressAndZeroPad(dbf, B_span, packed_ofs, to, kc); + // Verify that we zero-padded. + if constexpr (HWY_IS_DEBUG_BUILD) { + for (size_t i = kc; i < hwy::RoundUpTo(kc, hn::Lanes(dbf)); ++i) { + HWY_DASSERT(hwy::ConvertScalarTo(to[i]) == 0.0f); + } + } + } + } + + const MMArgs args_; // copy for locality + const size_t pkg_idx_; + RowPtrBF A_; // points into A or storage. + + const IndexRange range_np_; + // From MMConfig: + const size_t mr_; + const IndexRangePartition ranges_mc_; + const IndexRangePartition ranges_kc_; + const IndexRangePartition ranges_nc_; + const MMOrder order_; + const size_t inner_tasks_; + const MMOut out_; +}; // MMPerPackage + +// Stateless, wraps member functions. +struct MMImpl { + // Returns existing entry for the given key or -1. + static HWY_INLINE intptr_t IndexOfKey(MMKeys::Key key, const MMKeys& keys) { + const hwy::Span all_keys = keys.Keys(); + // TODO: SIMD scan + for (size_t i = 0; i < all_keys.size(); ++i) { + if (all_keys[i] == key) return static_cast(i); + } + return -1; + } + + // Called from `MatMul` from two places: either with the next autotune config, + // or with the best config. + template + static HWY_NOINLINE void DoMatMul(const ConstMat& A, + const ConstMat& B, const MMArgs& args, + const MMConfig& config) { + MMZone matmul_zone; + matmul_zone.MaybeEnter("MM.DoMatMul", args); + + // Outermost loop: static NUMA-aware partition of B rows across packages. 
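A simplified stand-in for the static per-package split used below (the real `RangesOfNP` uses `StaticPartition` with the `NPMultiple` granularity defined in matmul.cc); the helper and its rounding rule are hypothetical:

// Splits [0, n) into at most `parts` contiguous ranges whose sizes are
// multiples of `multiple`, except possibly the last.
static std::vector<std::pair<size_t, size_t>> SplitRowsSketch(size_t n, size_t parts,
                                                              size_t multiple) {
  std::vector<std::pair<size_t, size_t>> ranges;
  const size_t per_part =
      ((n + parts - 1) / parts + multiple - 1) / multiple * multiple;
  for (size_t begin = 0, p = 0; p < parts && begin < n; ++p) {
    const size_t end = (begin + per_part < n) ? begin + per_part : n;
    ranges.emplace_back(begin, end);
    begin = end;
  }
  return ranges;
}
// E.g. SplitRowsSketch(24576, 2, 1024) -> [0,12288) and [12288,24576): each
// package then runs MMPerPackage on its own share of B rows (= C columns).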
+ args.env->parallel.ForPkg( + args.per_key->ranges_np.NumTasks(), [&](size_t pkg_idx) { + const IndexRange& range_np = args.per_key->ranges_np.Range(pkg_idx); + MMPerPackage(A, args, config, pkg_idx, range_np)(B); + }); + } +}; // Computes the matrix product `A * B * scale [+ add]` and stores it in `C`. // -// `A` is a row-major matrix and `B` is transposed. Its `B.Extents().cols`, -// which must match `A.Extents().cols`, is the number of rows in the original B. +// `A` is a row-major matrix with `M` rows and `B` is transposed. The latter's +// `K = B.Extents().cols`, which must match `A.Extents().cols`, is the number +// of rows in the original B. `N = C.Cols()` must be a multiple of 4. There +// are no other restrictions on shape, though performance is better when `M % 4 +// == 0` or `M <= 4`. +// +// If `add` is non-null, the row-vector `add` is added to each of the `M` rows +// of `C`, which is a row-major matrix with arbitrary stride. A scale for +// `add` is not supported, so make sure its scale is 1. // -// If `add` is non-null, the row-vector `add` is added to each row of `C`. -// A scale for `add` is not supported, so make sure its scale is 1. +// Must not be called concurrently with the same `env`. The first few calls +// for a given shape will try different configs. The best is recorded in `env` +// and will be used for subsequent calls with that shape. // -// `C` is a row-major matrix of size `(A.rows, C.Cols())` with support for -// arbitrary strides. +// Returns the (autotuning) state for the current shape. This pointer may be +// invalidated by the next call to `MatMul`. // -// Updates 4x4 tiles of C in parallel using a work-stealing thread pool. -// Typically `A.rows` is 1..512, `A.Extents().cols` and `B.Extents().rows` are -// 3k or 24k. Must not be called concurrently with the same `env`. -template -HWY_NOINLINE void MatMul(const ConstMat& A, const ConstMat& B, - const float* HWY_RESTRICT add, MatMulEnv& env, - const RowPtr& C) { - if (add) { - MatMulImpl(A, B, add, env, C); - } else { - MatMulImpl(A, B, nullptr, env, C); +// Uses considerable stack space: at least 40 KiB per thread. +template +HWY_NOINLINE MMPerKey* MatMul(const ConstMat& A, const ConstMat& B, + const float* HWY_RESTRICT add, MatMulEnv& env, + const RowPtrF& C) { + const size_t M = A.Extents().rows; + const size_t K = A.Extents().cols; + const size_t N = B.Extents().rows; + const MMKeys::Key key = MMKeys::KeyFromDims(M, K, N); + intptr_t index = MMImpl::IndexOfKey(key, env.keys); + // First time we see this shape/key. + if (HWY_UNLIKELY(index < 0)) { + env.keys.Append(key); + + size_t max_packages = MMParallel::kMaxPackages; + // For low-batch, multiple sockets only help if binding is enabled. + if (!Allocator::ShouldBind() && M <= 4) { + static std::atomic_flag once = ATOMIC_FLAG_INIT; + if (!once.test_and_set()) { + HWY_WARN( + "Multiple sockets but binding disabled. Low-batch MatMul is only " + "using a single socket."); + } + max_packages = 1; + } + + // invalidates `MMAutoTune::Best()` + index = env.per_key.size(); + env.per_key.push_back(MMPerKey(max_packages, N, kNR, env.parallel)); + } + MMPerKey& per_key = env.per_key[index]; + MMAutoTune& tuner = per_key.autotune; + + const MMArgs args(env, per_key, static_cast(A.scale) * B.scale, add, + env.storage.Partial(), C); + if (HWY_LIKELY(tuner.Best())) { + MMImpl::DoMatMul(A, B, args, *tuner.Best()); + return &per_key; } + + PROFILER_ZONE("Matmul.Autotune"); + + // First call: enumerate all feasible configs. 
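The shape key computed above packs all three dimensions into one 64-bit value; matching the shifts in `MMKeys::KeyFromDims`, M occupies bits [0,16), K bits [16,40) and N bits [40,64). An unpacking sketch (not part of the API, shown only to make the layout concrete):

static void UnpackKeySketch(uint64_t key, size_t& M, size_t& K, size_t& N) {
  M = static_cast<size_t>(key & 0xFFFFu);            // batch size
  K = static_cast<size_t>((key >> 16) & 0xFFFFFFu);  // A/B columns
  N = static_cast<size_t>(key >> 40);                // B rows = C columns
}
// Example: M=512, K=3072, N=24576 gives
// key = 512 | (3072ull << 16) | (24576ull << 40).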
+ if (HWY_UNLIKELY(!tuner.HasCandidates())) { + // Ensure matrix dimensions match each other. + HWY_ASSERT(K == B.Extents().cols); + HWY_ASSERT(N == C.Cols()); + HWY_ASSERT(M <= MMStorage::kMaxM); + HWY_ASSERT(K <= MMStorage::kMaxK); + HWY_ASSERT(N <= MMStorage::kMaxN); + HWY_ASSERT(N % kNR == 0); + + // Negligible CPU time. + tuner.SetCandidates(MMCandidates(M, K, N, MMKernel::kMaxMR, kNR, + per_key.ranges_np, env.print_config)); + } + + const MMConfig& cfg = tuner.NextConfig(); + const uint64_t t0 = hwy::timer::Start(); + MMImpl::DoMatMul(A, B, args, cfg); + const uint64_t t1 = + env.have_timer_stop ? hwy::timer::Stop() : hwy::timer::Start(); + const double min_elapsed = static_cast(tuner.NotifyTicks(t1 - t0)) / + hwy::platform::InvariantTicksPerSecond(); + const double flops = 2 * M * K * N / min_elapsed; // * 2 for FMA + if (HWY_UNLIKELY(env.print_measurement && tuner.ShouldPrint())) { + fprintf(stderr, "%7.1f,%.2f,%zu,%4zu,%4zu,%5zu,%s,%zu,%s\n", flops * 1E-9, + min_elapsed * 1E3, cfg.MR(), cfg.MC(), cfg.KC(), cfg.NC(), + StringFromOrder(cfg.Order()), cfg.InnerTasks(), + StringFromOut(cfg.Out())); + } + if (HWY_UNLIKELY(env.print_best && tuner.Best())) { + const auto ratio = [per_key](uint64_t ticks) -> double { + return static_cast(ticks) / + static_cast(per_key.autotune.BestTicks()); + }; + const MMConfig& best = *tuner.Best(); + fprintf(stderr, + "\n%zu,%zu,%zu,%7.1f,%.2f,%zu,%4zu,%4zu,%5zu,%s,%zu,%s,%.2f,%.2f\n", + M, K, N, flops * 1E-9, min_elapsed * 1E3, best.MR(), best.MC(), + best.KC(), best.NC(), StringFromOrder(best.Order()), + best.InnerTasks(), StringFromOut(best.Out()), + ratio(tuner.WorstMinTicks()), ratio(tuner.FirstConfigTicks())); + } + + return &per_key; } // NOLINTNEXTLINE(google-readability-namespace-comments) diff --git a/ops/matmul.cc b/ops/matmul.cc new file mode 100644 index 0000000..80f1d8d --- /dev/null +++ b/ops/matmul.cc @@ -0,0 +1,415 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ops/matmul.h" + +// Analytical model of cache parameters for generating autotune candidates. + +#include +#include +#include + +#include + +#include "util/allocator.h" +#include "util/basics.h" +#include "util/threading.h" +#include "hwy/base.h" +#include "hwy/detect_targets.h" +#include "hwy/per_target.h" +#include "hwy/timer.h" + +namespace gcpp { +namespace { + +// Rounds down to a multiple of `multiple`, but returns at least `multiple`. +size_t RoundDownWithFloor(size_t value, size_t multiple) { + HWY_DASSERT(multiple != 0); + return HWY_MAX(multiple, hwy::RoundDownTo(value, multiple)); +} + +// Returns the highest number in `[begin, end)` that divides `dim` and is a +// multiple of `multiple`, or 0 if none exists. +size_t PrevDivisor(const size_t begin, const size_t end, const size_t dim, + const size_t multiple) { + HWY_DASSERT(end != 0 && dim != 0 && multiple != 0); + size_t prev = RoundDownWithFloor(end, multiple); + // Avoid returning `end` if rounding down had no effect. 
+ if (prev == end) prev -= multiple; + for (;;) { + if (prev == 0) return 0; // No divisor if large multiple or small end. + if (dim % prev == 0) return prev; + if (prev <= begin) return 0; + prev -= multiple; + } +} + +// Implementation of `MMCandidates`. Class hides the `KC` etc member functions +// and holds most of their arguments in member variables. +class GenerateCandidates { + public: + GenerateCandidates(size_t M, size_t K, size_t N, size_t max_mr, size_t nr, + const IndexRangePartition& ranges_np, bool print_config) + : M_(M), + K_(K), + max_mr_(max_mr), + nr_(nr), + // These influence kc/nc, but are also stored in `MMConfig` for + // `RangesOf*`. Must be a vector multiple. The previous/next cache line + // is likely still in L1, but we expect K > 1000 and might as well round + // up to the line size. + kc_multiple_(HWY_MIN(K, Allocator::LineBytes() / sizeof(BF16))), + nc_multiple_(Allocator::StepBytes() / sizeof(float)), + ranges_np_(ranges_np), + print_config_(print_config) {} + + std::vector operator()() const { + std::vector candidates; + candidates.reserve(128); + + for (size_t mr : MR()) { + for (MMOrder order : Orders(mr)) { + const std::vector& all_inner_tasks = InnerTasks(order); + const std::vector& all_outs = Outs(order); + for (size_t kc : KC(mr, order)) { + for (size_t mc : MC(mr, kc, order)) { + for (size_t nc : NC(mr, mc, kc, order)) { + for (int inner_tasks : all_inner_tasks) { + for (MMOut out : all_outs) { + const MMConfig config(K_, mr, mc, kc, nc, kc_multiple_, + nc_multiple_, order, out, inner_tasks); + const size_t M_tasks = config.RangesOfMC(M_).NumTasks(); + const size_t K_tasks = config.RangesOfKC(K_).NumTasks(); + + // Blocks only make sense when there are multiple M tasks. + if (IsBlock(order) != (M_tasks > 1)) continue; + // Single KC only makes sense when there is a single K task. + if (IsOneKC(order) != (K_tasks == 1)) continue; + + candidates.push_back(config); + } + } + } + } + } + } + } + + HWY_ASSERT(!candidates.empty()); + return candidates; + } + + private: + using SizeVec = std::vector; + + // How many rows of A per call to `MMKernel::LoopOverKC`. Lower values may + // be better for SIMD targets with fewer registers. + SizeVec MR() const { + const int64_t target = hwy::DispatchedTarget(); + const bool is_avx2 = target == HWY_AVX2; + const bool is_sse = HWY_SSE4 <= target && target <= HWY_SSE2; + const bool is_wasm = target == HWY_WASM || target == HWY_WASM_EMU256; + + SizeVec all_mr; + all_mr.reserve(3); + // AVX2's 16 registers are not enough for four rows, but SSE4 may benefit. + if (M_ >= max_mr_ && !is_avx2) all_mr.push_back(max_mr_); + // Allow for AVX-512 but not SSE4 (for which 4 are usually better). Also + // enable if not enough rows for 4. + if (M_ >= 2 && (M_ < max_mr_ || (!is_sse && !is_wasm))) { + all_mr.push_back(size_t{2}); + } + // Even SSE4 usually prefers 2 rows; only enable for single rows. + if (M_ == 1) all_mr.push_back(size_t{1}); + HWY_ASSERT(!all_mr.empty()); + return all_mr; + } + + // Which loop orders to enable depending on M. + std::vector Orders(size_t mr) const { + std::vector orders; + for (size_t order_idx = 0;; ++order_idx) { + const MMOrder order = static_cast(order_idx); + if (StringFromOrder(order) == nullptr) return orders; // done + // 2D blocking is useless for a single row of M. + if (IsBlock(order) && M_ <= mr) continue; + // Conversely, N-only parallelism is uncompetitive for large M. 
+ if (!IsBlock(order) && M_ >= 8 * mr) continue; + orders.push_back(order); + } + } + + // The number of A and B columns to read between updating `partial`. + SizeVec KC(size_t mr, MMOrder order) const { + // `LoopOverKC` handles up to `mr` rows of A. + const size_t rows_a = HWY_MIN(M_, mr); + + // After looping over `kc` columns, we write `mr x 4` outputs and 16 vector + // `buf`. To amortize the write cost, we want to maximize `kc`. However, it + // is important that B fits in L1, because batch=1 only has a single row of + // A and thus no reuse of the packed B. When L1-resident, we can use the + // separate `DecompressAndZeroPad` to write `kc` columns, rather than having + // to integrate `Decompress2` into `LoopOverKC`, which is less efficient for + // TB=NUQ due to less amortization of the table loads. Due to the low L1 + // latency, the packing is still effectively fused into `LoopOverKC`. It may + // be better to round up and accept a few L2 accesses in exchange for + // fewer loops over K, and thus fewer writes to `partial`. Hence we do not + // subtract the output and buf, and allow using more than the actual L1 + // size. This results in an overestimate, and the loop below will propose + // the next few smaller values for the autotuner to evaluate. + const size_t bytes_ab = Allocator::L1Bytes() * 3; + const size_t col_bytes = rows_a * sizeof(BF16) + nr_ * sizeof(BF16); + size_t kc_max = hwy::DivCeil(bytes_ab, col_bytes); + kc_max = + RoundDownWithFloor(HWY_MIN(kc_max, MMStorage::kMaxKC), kc_multiple_); + kc_max = HWY_MIN(kc_max, K_); + + SizeVec all_kc(1, kc_max); + + // Avoid proposing kc > K. + if (K_ > kc_multiple_) { + // Generally it is best to use the full `kc` (fewer writes to `partial`), + // but a bit less can be better if it evenly divides `K`, or enables an + // `mc` that evenly divides `M`. Try several smaller values. + + // If we can afford a single K task, that's usually best; only try one + // more. Otherwise, blocks may require smaller kc (more options). + const size_t reps = (kc_max == K_) ? 1 : IsBlock(order) ? 3 : 2; + + size_t prev = kc_max; + for (size_t rep = 0; rep < reps; ++rep) { + const size_t div = PrevDivisor(kc_multiple_, prev, K_, kc_multiple_); + prev = div ? div : RoundDownWithFloor(prev / 2, kc_multiple_); + all_kc.push_back(prev); + } + } + + if (print_config_ && all_kc.size() > 1) { + fprintf(stderr, "KC: "); + for (size_t kc : all_kc) { + fprintf(stderr, "%zu ", kc); + } + fprintf(stderr, "\n"); + } + + return all_kc; + } + + // The number of (L2 resident) A rows for `A2C0` to loop over. + SizeVec MC(size_t mr, size_t kc, MMOrder order) const { + // Typically 12-24K. The B rows are pinned in L1, but also occupy L2 because + // it is typically inclusive. + const size_t bytes_b = nr_ * kc * (sizeof(SfpStream) + sizeof(BF16)); + + // Choose the largest feasible `mc_max` (A/C rows) to maximize reuse of the + // packed B. We want `mc * kc` elements of A to fit in L2, alongside + // `bytes_b` plus `mc` cache lines because resident-A updates `mc` rows of + // partial. + const size_t bytes_per_mc = kc * sizeof(BF16) + Allocator::LineBytes(); + size_t mc_max = hwy::DivCeil(Allocator::L2Bytes() - bytes_b, bytes_per_mc); + mc_max = HWY_MIN(mc_max, MMStorage::kMaxM); + HWY_DASSERT(mc_max != 0); + mc_max = HWY_MIN(mc_max, M_); + mc_max = hwy::RoundDownTo(mc_max, mr); + + SizeVec all_mc(1, mc_max); + // Larger MC is better for non-blocks, otherwise we want more small options. + const size_t reps = !IsBlock(order) ? 
2 : 3; + + size_t prev = mc_max; + for (size_t rep = 0; rep < reps; ++rep) { + prev = PrevDivisor(1, prev, M_, mr); + if (prev >= mc_max || prev == 0) break; + all_mc.push_back(prev); + } + + // Blocks: largest is not useful. + if (IsBlock(order) && all_mc.size() > 1) { + all_mc.erase(all_mc.begin(), all_mc.begin() + 1); + } + + if (print_config_ && all_mc.size() > 1) { + fprintf(stderr, "MC: "); + for (size_t mc : all_mc) { + fprintf(stderr, "%zu ", mc); + } + fprintf(stderr, "\n"); + } + + return all_mc; + } + + // The number of (possibly L3 resident) B rows per `NT_MT` task. + SizeVec NC(size_t mr, size_t mc, size_t kc, MMOrder order) const { + const size_t np_max = ranges_np_.TaskSize(); + size_t nc_max = np_max; + const size_t out_bytes = IsOneKC(order) ? sizeof(float) : sizeof(double); + // Only if there will be reuse of B: choose the largest `nc_max` (C cols) + // such that `nc x kc` of B and `mc x nc` of `partial` or `C` fit in L3. + // Otherwise, leave it unbounded. + if (M_ > mr) { + const size_t bytes_per_nc = (kc * sizeof(BF16) + mc * out_bytes); + nc_max = hwy::DivCeil(Allocator::L3Bytes(), bytes_per_nc); + nc_max = HWY_MIN(HWY_MIN(nc_max, MMStorage::kMaxN), np_max); + } + HWY_DASSERT(nc_max != 0); + nc_max = RoundDownWithFloor(nc_max, nc_multiple_); + + // If there are going to be multiple ranges, anything more than half would + // be imbalanced and suboptimal. + if (nc_max < np_max && nc_max >= np_max / 2) { + nc_max = RoundDownWithFloor(np_max / 2, nc_multiple_); + } + + // Non-block calls ForNP, which ignores `range_nc` and uses `range_np`. + if (!IsBlock(order)) return SizeVec(1, np_max); + + SizeVec all_nc(1, nc_max); + + // Avoid proposing nc > N. + if (np_max > nc_multiple_) { + // Large L3, but its behavior and characteristics varies across platforms, + // hence autotune a wider range of nc than the other dimensions. + size_t reps = 10; + // For small M, we can afford larger NC, hence allow fewer small options. + if (M_ <= 2 * mr) reps -= 1; + + size_t prev = nc_max; + for (size_t rep = 0; rep < reps; ++rep) { + const size_t div = + PrevDivisor(nc_multiple_, prev, np_max, nc_multiple_); + prev = div ? div : RoundDownWithFloor(prev / 2, nc_multiple_); + all_nc.push_back(prev); + if (prev == nc_multiple_) break; + } + + // Skip the larger values (unlikely to be chosen), keep about 40%. + const ptrdiff_t want_delete = + static_cast(all_nc.size() * 5 / 9 + 2); + // Keep at least 2. + const ptrdiff_t max_delete = + HWY_MAX(static_cast(all_nc.size()) - 2, ptrdiff_t{0}); + all_nc.erase(all_nc.begin(), + all_nc.begin() + HWY_MIN(want_delete, max_delete)); + } + + if (print_config_ && all_nc.size() > 1) { + fprintf(stderr, "NC: "); + for (size_t nc : all_nc) { + fprintf(stderr, "%zu ", nc); + } + fprintf(stderr, "\n"); + } + + return all_nc; + } + + // How many tasks per cluster worker. More = smaller tasks, which can lead + // to better load balancing at the cost of higher overhead. + std::vector InnerTasks(MMOrder order) const { + std::vector inner_tasks; + inner_tasks.reserve(3); + inner_tasks.push_back(1); + // Blocks have one task per mc/nc range and ignore this parameter. + if (!IsBlock(order)) { + inner_tasks.push_back(2); + inner_tasks.push_back(4); + } + return inner_tasks; + } + + // Whether to parallelize FillC or enable direct writes to C. 
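The L2 heuristic in `MC()` above can be spelled out numerically. A sketch with illustrative sizes (1 MiB L2, 64-byte lines, kNR = 4) rather than values queried from `Allocator`:

// mc_max = A rows (kc BF16 columns each) that fit in L2 next to the kNR
// packed B rows (SFP source + BF16 copy) plus one partial/C line per A row.
static size_t McMaxSketch(size_t kc, size_t l2_bytes, size_t line_bytes) {
  const size_t bytes_b = 4 * kc * (1 + 2);            // kNR * (SfpStream + BF16)
  const size_t bytes_per_mc = kc * 2 + line_bytes;    // BF16 row + one line
  return (l2_bytes - bytes_b + bytes_per_mc - 1) / bytes_per_mc;  // DivCeil
}
// E.g. kc=2048: bytes_b = 24 KiB and bytes_per_mc = 4160, so a 1 MiB L2 yields
// about 247 rows, which MC() then rounds down to a multiple of mr.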
+ std::vector Outs(MMOrder order) const { + std::vector outs; + for (size_t out_idx = 0;; ++out_idx) { + const MMOut out = static_cast(out_idx); + if (StringFromOut(out) == nullptr) return outs; // done + // kParM only makes sense if we have more than one row of A. + if (out == MMOut::kParM && M_ == 1) continue; + // Blocks are already parallelized. + if (out == MMOut::kParM && IsBlock(order)) continue; + // Direct only works for a single kc range. + if ((out == MMOut::kDirect) != IsOneKC(order)) continue; + // For non-block, kCopy does not beat kDirect. + if (out == MMOut::kCopy && IsOneKC(order) && !IsBlock(order)) continue; + outs.push_back(out); + } + } + + const size_t M_; + const size_t K_; + + const size_t max_mr_; + const size_t nr_; + + const size_t kc_multiple_; + const size_t nc_multiple_; + + IndexRangePartition ranges_np_; + + const bool print_config_; +}; + +} // namespace + +// Facade to avoid exposing `GenerateCandidates` in the header. +std::vector MMCandidates(size_t M, size_t K, size_t N, size_t max_mr, + size_t nr, + const IndexRangePartition& ranges_np, + bool print_config) { + return GenerateCandidates(M, K, N, max_mr, nr, ranges_np, print_config)(); +} + +// Returns the granularity of B rows for `RangesOfNP`. Aims to avoid remote +// memory accesses or false sharing, unless there are insufficient per-package +// rows for that. +static size_t NPMultiple(size_t N, size_t nr, size_t num_packages) { + size_t np_multiple = Allocator::QuantumBytes() / sizeof(float); + // If binding, `np_multiple` is typically 1024 and `num_packages` > 1. For + // `N` < 4096, this can cause significant load imbalance. If split unevenly, + // choose a smaller multiple. + if (N % (np_multiple * num_packages)) { + const size_t min_multiple = Allocator::LineBytes() / sizeof(float); + np_multiple = + PrevDivisor(min_multiple, np_multiple, N / num_packages, min_multiple); + if (HWY_UNLIKELY(np_multiple == 0)) { + np_multiple = min_multiple; + } + // This happens in tests with small N, hence do not assert. + if (N % (np_multiple * num_packages) && N >= 128) { + HWY_WARN("NPMultiple: N=%zu still not divisible by np_multiple=%zu\n", N, + np_multiple); + np_multiple = nr; + } + } + return np_multiple; +} + +IndexRangePartition MMParallel::RangesOfNP(size_t max_packages, size_t N, + size_t nr) const { + const size_t num_packages = HWY_MIN(max_packages, pools_.NumPackages()); + return StaticPartition(IndexRange(0, N), num_packages, + NPMultiple(N, nr, num_packages)); +} + +MatMulEnv::MatMulEnv(NestedPools& pools) : parallel(pools), storage(parallel) { + // Ensure Allocator:Init was called. + HWY_ASSERT(Allocator::LineBytes() != 0 && Allocator::VectorBytes() != 0); + + char cpu100[100]; + have_timer_stop = hwy::platform::HaveTimerStop(cpu100); +} + +} // namespace gcpp diff --git a/ops/matmul.h b/ops/matmul.h index c1c6d44..707e37b 100644 --- a/ops/matmul.h +++ b/ops/matmul.h @@ -16,86 +16,667 @@ #ifndef THIRD_PARTY_GEMMA_CPP_OPS_MATMUL_H_ #define THIRD_PARTY_GEMMA_CPP_OPS_MATMUL_H_ +// Non-SIMD part of MatMul: parallelization, allocation, and autotuning. 
+ #include +#include + +#include // IWYU pragma: begin_exports #include "compression/compress.h" #include "util/allocator.h" #include "util/basics.h" #include "util/threading.h" +#include "hwy/aligned_allocator.h" // Span #include "hwy/base.h" +#include "hwy/bit_set.h" #include "hwy/contrib/thread_pool/thread_pool.h" +#include "hwy/profiler.h" // IWYU pragma: end_exports -#include "hwy/per_target.h" // VectorBytes - namespace gcpp { -// TODO: remove deprecated typedef. -using Range1D = IndexRange; - // The MatMul result C[r,c] is Dot(A.Row(r), B.Col(c)). To reduce the number of // loads, we reuse the same A row for several B columns, which are also loaded // once for several rows of C. Thus we produce one 'tile' of C at a time of -// dimensions `kRegRows` x `kRegCols`. The Reg naming is because these are -// limited by the number of registers: 32 for NEON/SVE/AVX-512. `kRegCols` == 4 -// enables the `StoreInterleaved4` transpose in `StoreHorizontalSums`. We assume -// and verify that `C.Cols() % kRegCols == 0`. -constexpr size_t kRegCols = 4; - -// Choosing `kRegRows == kRegCols` minimizes the ratio of loads to FMA, because -// we load `kRegCols + kRegRows` vectors per `kRegRows * kRegCols` element tile. -// In general, `batch_size` (A/C rows) is not a multiple of `kRegRows`. Thus -// functions that load or store a tile are parameterized on `kRowsPerTile`: -// usually `kRegRows`, but `batch_size % kRegRows` on the last row (if != 0). -constexpr size_t kRegRows = kRegCols; - -struct CacheSizes { - CacheSizes() = default; - CacheSizes(const BoundedTopology::Cluster& cluster) { - // Assumes each package and cluster has the same cache sizes, and uses - // reasonable defaults if unknown. - l1_bytes = 32 * 1024; // typical size, rarely changes - l2_bytes = (cluster.PrivateKiB() ? cluster.PrivateKiB() : 256) * 1024; - l3_bytes = (cluster.SharedKiB() ? cluster.SharedKiB() : 1024) * 1024; - } - - size_t l1_bytes; - size_t l2_bytes; - size_t l3_bytes; -}; +// dimensions `mr (<= kMaxMR)` x `kNR`. To keep FMA units busy, this should be +// at least the product of the FMA latency (3..5) times the throughput (2). +// This and `mr` are limited by the number of registers, which is generally +// 32 but 16 for AVX2. `kNR` == 4 enables the `StoreInterleaved4` transpose in +// `MMAddHorizontalSumsIntoPartial`. We ensure `C.Cols() % kNR == 0`. +constexpr size_t kNR = 4; class MMParallel { public: - MMParallel() : pools_(nullptr) {} - explicit MMParallel(NestedPools& pools) : pools_(&pools) {} + static constexpr size_t kMaxPackages = 4; + + MMParallel(NestedPools& pools) : pools_(pools) { + HWY_DASSERT(pools_.NumPackages() <= kMaxPackages); + } + + // Used by tests. + NestedPools& Pools() { return pools_; } + + // Initial static partitioning of B rows across packages. + IndexRangePartition RangesOfNP(size_t max_packages, size_t N, + size_t nr) const; + + // For `BindB` and `BindC`. + size_t Node(size_t pkg_idx) const { + return pools_.Topology().GetCluster(pkg_idx, 0).Node(); + } + + // Calls `func(pkg_idx)` for each package in parallel. + template + void ForPkg(const size_t max_packages, const Func& func) { + pools_.AllPackages().Run(0, HWY_MIN(max_packages, pools_.NumPackages()), + [&](uint64_t task, size_t pkg_idx) { + HWY_DASSERT(task == pkg_idx); + (void)task; + func(pkg_idx); + }); + } + + // Cluster/CCX-aware parallel-for over B rows in `range_np`. `nx_multiple` is + // the granularity of per-cluster tasks. Calls `func(worker_range)`. 
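A rough sequential stand-in for the two-level split that `ForNP` below performs with thread pools; the helper, its rounding rule and the omission of `inner_tasks` are simplifications:

template <class Func>
void ForNPSketch(size_t begin, size_t end, size_t nx_multiple,
                 size_t num_clusters, size_t workers_per_cluster,
                 const Func& func) {
  // Split [b, e) into `parts` pieces at nx_multiple granularity; return piece i.
  const auto split = [nx_multiple](size_t b, size_t e, size_t parts, size_t i,
                                   size_t& out_b, size_t& out_e) {
    const size_t per = ((e - b + parts - 1) / parts + nx_multiple - 1) /
                       nx_multiple * nx_multiple;
    out_b = b + i * per;
    out_e = (out_b + per < e) ? out_b + per : e;
  };
  for (size_t c = 0; c < num_clusters; ++c) {  // one static range per cluster
    size_t cb, ce;
    split(begin, end, num_clusters, c, cb, ce);
    if (cb >= ce) continue;
    for (size_t w = 0; w < workers_per_cluster; ++w) {  // within the cluster
      size_t wb, we;
      split(cb, ce, workers_per_cluster, w, wb, we);
      if (wb < we) func(wb, we);  // one `worker_range`
    }
  }
}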
+ template + void ForNP(const IndexRange& range_np, size_t nx_multiple, size_t inner_tasks, + size_t pkg_idx, const Func& func) { + HWY_DASSERT(1 <= inner_tasks && inner_tasks <= 4); + // Single cluster: parallel-for over static partition of `range_np`. + hwy::ThreadPool& all_clusters = pools_.AllClusters(pkg_idx); + const size_t num_clusters = all_clusters.NumWorkers(); + if (num_clusters == 1) { + hwy::ThreadPool& cluster = pools_.Cluster(pkg_idx, 0); + const IndexRangePartition worker_ranges = StaticPartition( + range_np, cluster.NumWorkers() * inner_tasks, nx_multiple); + return ParallelizeOneRange( + worker_ranges, cluster, + [&](const IndexRange& worker_range, size_t /*thread*/) { + func(worker_range); + }); + } + + // Assign each cluster a sub-range of `range_np` (typically hundreds). + const IndexRangePartition nx_ranges = + StaticPartition(range_np, num_clusters, nx_multiple); + ParallelizeOneRange( + nx_ranges, all_clusters, + [&](const IndexRange& nx_range, const size_t cluster_idx) { + hwy::ThreadPool& cluster = pools_.Cluster(pkg_idx, cluster_idx); + // Parallel-for over sub-ranges of `cluster_range` within the cluster. + const IndexRangePartition worker_ranges = StaticPartition( + nx_range, cluster.NumWorkers() * inner_tasks, nx_multiple); + ParallelizeOneRange(worker_ranges, cluster, + [&](const IndexRange& worker_range, + size_t /*thread*/) { func(worker_range); }); + }); + } + + // Cluster/CCX-aware parallel-for over blocks (separate subranges of A and B + // rows). Calls `func(range_mc, range_nc)`. + template + void ForRangesMC_NC(const IndexRangePartition& ranges_mc, + const IndexRangePartition& ranges_nc, size_t pkg_idx, + const Func& func) { + hwy::ThreadPool& all_clusters = pools_.AllClusters(pkg_idx); + // `all_clusters` is a pool with one worker per cluster in a package. + const size_t num_clusters = all_clusters.NumWorkers(); + // Single (big) cluster: collapse two range indices into one parallel-for + // to reduce the number of fork-joins. + if (num_clusters == 1) { + const size_t cluster_idx = 0; + hwy::ThreadPool& cluster = pools_.Cluster(pkg_idx, cluster_idx); + // Low-batch: avoid Divide/Remainder. + if (HWY_UNLIKELY(ranges_mc.NumTasks() == 1)) { + return ParallelizeOneRange( + ranges_nc, cluster, + [&](const IndexRange& range_nc, size_t /*thread*/) { + func(ranges_mc.Range(0), range_nc); + }); + } else { + return ParallelizeTwoRanges( + ranges_mc, ranges_nc, cluster, + [&](const IndexRange& range_mc, const IndexRange& range_nc, + size_t /*thread*/) { func(range_mc, range_nc); }); + } + } - NestedPools& Pools() const { return *pools_; } - hwy::ThreadPool& Pool() const { return pools_->Pool(); } + // Multiple clusters: N across clusters (both are usually the larger), and + // M within each cluster. We assume auto-tuning finds small MC/NC tasks. + ParallelizeOneRange( + ranges_nc, all_clusters, + [&](const IndexRange range_nc, size_t cluster_idx) { + hwy::ThreadPool& cluster = pools_.Cluster(pkg_idx, cluster_idx); + ParallelizeOneRange( + ranges_mc, cluster, + [&](const IndexRange& range_mc, size_t /*thread*/) { + func(range_mc, range_nc); + }); + }); + } + + // Calls `func(row_a)` in parallel. + template + void ForRangeMC(const IndexRange& range_mc, size_t pkg_idx, + const Func& func) { + pools_.Pool(pkg_idx).Run( + range_mc.begin(), range_mc.end(), + [&](uint64_t row_a, size_t /*thread*/) { func(row_a); }); + } private: - NestedPools* pools_; + NestedPools& pools_; }; -// Allocations and threads, shared across MatMul calls. 
-class MatMulEnv { +template // float for C, double for partial +void BindC(size_t M, const RowPtr& C, MMParallel& parallel) { + if (!Allocator::ShouldBind()) return; + + const IndexRangePartition ranges_np = + parallel.RangesOfNP(MMParallel::kMaxPackages, C.Cols(), kNR); + const size_t quantum = Allocator::QuantumBytes() / sizeof(T); + bool ok = true; + for (size_t pkg_idx = 0; pkg_idx < ranges_np.NumTasks(); ++pkg_idx) { + const IndexRange& cols_c = ranges_np.Range(pkg_idx); + const size_t node = parallel.Node(pkg_idx); + for (size_t im = 0; im < M; ++im) { + // BindRowsToPackageNodes may not be page-aligned. + const size_t begin = hwy::RoundUpTo(cols_c.begin(), quantum); + const size_t end = hwy::RoundDownTo(cols_c.end(), quantum); + ok &= Allocator::BindMemory(C.Row(im) + begin, (end - begin) * sizeof(T), + node); + } + } + if (HWY_UNLIKELY(!ok)) { + HWY_WARN("Failed to bind C (%zux%zu), %zu packages.", M, C.Cols(), + ranges_np.NumTasks()); + } +} + +// Per-package storage for packed A, and one global C-shaped `partial` for +// accumulating partial dot products (sections of K). +class MMStorage { public: - explicit MatMulEnv(NestedPools& pools) : parallel(pools) { - const size_t N = hwy::VectorBytes() / sizeof(float); - buf_ = RowVectorBatch(Extents2D(pools.MaxWorkers(), 16 * N)); + // Compile-time bounds on matrix dimensions to enable pre-allocating storage + // and reusing it across `MatMul` calls. The resulting allocations are 256 MiB + // per package and 512 MiB, respectively. + static constexpr size_t kMaxM = 2048; + static constexpr size_t kMaxK = 64 * 1024; + static constexpr size_t kMaxN = 256 * 1024; + // Upper bound for per-worker B storage on the stack. Chosen such that one row + // of BF16 A and B fit in 32 KiB L1, but there may be `kMaxMR` and `kNR`. + static constexpr size_t kMaxKC = 8 * 1024; + + explicit MMStorage(MMParallel& parallel) { + // Per-package allocation so each can decompress A into its own copy. + parallel.ForPkg(MMParallel::kMaxPackages, [&](size_t pkg_idx) { + pkg_A_[pkg_idx] = AllocateAlignedRows(Extents2D(kMaxM, kMaxK)); + + if (Allocator::ShouldBind()) { + const size_t node = parallel.Node(pkg_idx); + if (!Allocator::BindMemory(pkg_A_[pkg_idx].All(), + pkg_A_[pkg_idx].NumBytes(), node)) { + HWY_WARN("Failed to bind memory for package %zu", pkg_idx); + } + } + }); + + // Per-worker copies of `partial` would be wasteful. We instead allocate + // one instance of the maximum matrix extents because threads write at + // false-sharing-free granularity. + partial_storage_ = AllocateAlignedRows(Extents2D(kMaxM, kMaxN)); + // Same stride independent of the actual C.Cols() so we can pre-bind. + partial_ = RowPtrD(partial_storage_.All(), kMaxN, + StrideForCyclicOffsets(kMaxN)); + // Avoid cross-package accesses. + BindC(kMaxM, partial_, parallel); + } + + // Returns per-package matrix view. Non-const so that `RowVectorBatch` is + // non-const, because `RowPtr` requires a non-const pointer. 
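The quantum rounding in `BindC` above, in isolation: only the aligned middle of each per-package column range is bound, and the ragged edges stay on the default node. A sketch assuming a hypothetical 4 KiB bind quantum (the real value is `Allocator::QuantumBytes()`):

static void AlignedSubrangeSketch(size_t col_begin, size_t col_end,
                                  size_t quantum_elems,
                                  size_t& bind_begin, size_t& bind_end) {
  bind_begin = (col_begin + quantum_elems - 1) / quantum_elems * quantum_elems;
  bind_end = (col_end / quantum_elems) * quantum_elems;  // round down
  if (bind_end < bind_begin) bind_end = bind_begin;      // nothing to bind
}
// E.g. float columns [1000, 9000) with a 4 KiB quantum (1024 floats) binds
// [1024, 8192), i.e. 7168 floats per row, to the package's node.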
+ RowPtrBF A(size_t pkg_idx, const Extents2D& extents) { + HWY_DASSERT(extents.rows <= kMaxM); + HWY_DASSERT(extents.cols <= kMaxK); + const size_t stride = StrideForCyclicOffsets(extents.cols); + return RowPtrBF(pkg_A_[pkg_idx].All(), extents.cols, stride); + } + + RowPtrD Partial() const { return partial_; } + + private: + RowVectorBatch pkg_A_[MMParallel::kMaxPackages]; + RowVectorBatch partial_storage_; + RowPtrD partial_; +}; + +//------------------------------------------------------------------------------ +// Autotuning + +// Naming convention: outer loop first, T suffix means threaded. This refers to +// the loops *around* `A2C0`, which contains loops over mc/kc. The outermost +// `ranges_np` loop across packages is implicit and applies to all of these. +// +// Parallelizing across K (A/B columns) is undesirable because the resulting +// partial dot products require synchronization or reduction across threads. +enum class MMOrder : uint8_t { + // Single M, parallel N, sequential K (inside the parallel section to + // reduce fork-joins). Similar to GotoBLAS, good for large N vs. M and K. + kNT_K, + // Specialization of `kNT_K` for a single K task with `kDirect`. + kNT, + + // Parallelize over blocks of M and N: good when both are large. We no longer + // support `kMT_NT_K`: no advantage on Skylake, and `kNT_MT_K` is 1.5x as + // fast on Zen4. + kNT_MT_K, + kNT_MT, // Specialization of `kNT_MT_K` for a single K task with `kDirect`. + + // Resident C (`kK_M_NT`) should be good for large K relative to M and N. + // However, it does not (much) outperform `kNT_K` on SKX and Zen4. There are + // no kN* because we expect M (batch size) to be small relative to K and N. +}; + +static inline bool IsBlock(MMOrder order) { + return order == MMOrder::kNT_MT_K || order == MMOrder::kNT_MT; +} + +static inline bool IsOneKC(MMOrder order) { + return order == MMOrder::kNT || order == MMOrder::kNT_MT; +} + +static inline const char* StringFromOrder(MMOrder order) { + switch (order) { + case MMOrder::kNT_K: + return "NT_K"; + case MMOrder::kNT: + return "NT"; + case MMOrder::kNT_MT_K: + return "NT_MT_K"; + case MMOrder::kNT_MT: + return "NT_MT"; + default: + return nullptr; + } +} + +// How/where to write the A2C0 result. This determines the `tag` argument to +// that function, which governs whether we call `MMStoreHorizontalSumsIntoC` or +// `MMAddHorizontalSumsIntoPartial`. +enum class MMOut : uint8_t { + kCopy, // accumulate into partial, scale/add to C + kDirect, // single kc task, write directly to C + kParM // kCopy but parallel over M + // kParN is not better on SKX/Zen4. +}; + +static inline const char* StringFromOut(MMOut out) { + switch (out) { + case MMOut::kDirect: + return "Direct"; + case MMOut::kCopy: + return "Copy"; + case MMOut::kParM: + return "ParM"; + default: + return nullptr; + } +} + +// How to parallelize the per-package `DecompressA`. To reduce combinatorial +// explosion, we tune this separately from `MMConfig`. 
+enum class MMParA : uint8_t { kNone, kK1 = 1, kK2 = 2, kK4 = 4, kM }; + +static inline const char* StringFromParA(MMParA par_a) { + switch (par_a) { + case MMParA::kNone: + return "ParA0 "; + case MMParA::kK1: + return "ParAK1"; + case MMParA::kK2: + return "ParAK2"; + case MMParA::kK4: + return "ParAK4"; + case MMParA::kM: + return "ParAM "; + default: + return nullptr; + } +} + +// Possible configurations for the autotuner to choose from: +// `mr` := C rows to write at a time (< #registers / `kNR`), +// `kc` := A / B columns such that `mr` rows fit in L1, +// `mc` := A rows such that `kc` columns fit in L2, +// `nc` := B rows such that `kc` columns fit in L3 alongside `mc x nc` C. +// Also includes loop order and task granularity. +#pragma pack(push, 1) +class MMConfig { + public: + MMConfig() = default; // for std::vector + // `mr` is the number of A rows per call to `MMKernel::LoopOverKC`. + // `MMOrder` is how to parallelize the outer loops. + // `MMOut` is how/whether to parallelize filling the C result. + // `inner_tasks` chooses the within-cluster task granularity in `ForNP`. + MMConfig(size_t K, size_t mr, size_t mc, size_t kc, size_t nc, + size_t kc_multiple, size_t nc_multiple, MMOrder order, MMOut out, + int inner_tasks) + : mr_(static_cast(mr)), + mc_(static_cast(mc)), + kc_(static_cast(kc)), + nc_(static_cast(nc)), + nc_multiple_(static_cast(nc_multiple)), + kc_multiple_(static_cast(kc_multiple)), + order_(order), + out_(out), + inner_tasks_(static_cast(inner_tasks)), + reserved_{} { + HWY_DASSERT(mr == 1 || mr == 2 || mr == 4); + if (mc % mr != 0) { + HWY_WARN("mc %zu not a multiple of mr %zu", mc, mr); + } + // Do not warn for single-kc tasks; some models unfortunately have K which + // are not multiples of `kc_multiple`. + if (kc != K && (kc % kc_multiple) != 0) { + HWY_WARN("kc %zu not a multiple of kc_multiple %zu", kc, kc_multiple); + } + if (nc % nc_multiple != 0) { + HWY_WARN("nc %zu not a multiple of nc_multiple %zu", nc, nc_multiple); + } + HWY_DASSERT(StringFromOrder(order_) != nullptr); + HWY_DASSERT(StringFromOut(out_) != nullptr); + HWY_DASSERT(1 <= inner_tasks && inner_tasks <= 4); + } + + // Splits M/N into blocks which are visited sequentially or in parallel. + // K is always sequential, see `MMOrder`. + IndexRangePartition RangesOfMC(size_t M) const { + return MaxSizePartition(IndexRange(0, M), mc_, mr_); + } + IndexRangePartition RangesOfKC(size_t K) const { + return MaxSizePartition(IndexRange(0, K), kc_, kc_multiple_); + } + IndexRangePartition RangesOfNC(IndexRange range_np) const { + return MaxSizePartition(range_np, nc_, nc_multiple_); } - RowVectorBatch& Buf() { return buf_; } + MMOrder Order() const { return order_; } + MMOut Out() const { return out_; } + // No `OuterTasks` because static partitioning across clusters is sufficient. + size_t InnerTasks() const { return static_cast(inner_tasks_); } + + // Accessors for printing autotune result. + size_t MR() const { return static_cast(mr_); } + size_t MC() const { return static_cast(mc_); } + size_t KC() const { return static_cast(kc_); } + size_t NC() const { return static_cast(nc_); } + + private: + // Somewhat-compressed representation because MMCandidates may return dozens. 
+ uint32_t mr_; + uint32_t mc_; + uint32_t kc_; + uint32_t nc_; + uint32_t nc_multiple_; + uint32_t kc_multiple_; + MMOrder order_; + MMOut out_; + uint8_t inner_tasks_; + HWY_MAYBE_UNUSED uint8_t reserved_[5]; +}; +static_assert(sizeof(MMConfig) == 32); // for faster indexing +#pragma pack(pop) + +std::vector MMCandidates(size_t M, size_t K, size_t N, size_t max_mr, + size_t nr, + const IndexRangePartition& ranges_np, + bool print_config); + +// State machine for choosing the best `TConfig`, which is `MMConfig` for the +// main MatMul autotuner. +template +class MMAutoTune { + public: + // Returns nullptr if not yet finished, otherwise the best config. Do not + // store this pointer because it can be invalidated. + const TConfig* Best() const { return best_; } + + // If false, caller must call `SetCandidates` before `NextConfig`. + bool HasCandidates() const { + HWY_DASSERT(!Best()); + return !candidates_.empty(); + } + void SetCandidates(std::vector candidates) { + HWY_DASSERT(!HasCandidates()); + candidates_.swap(candidates); + HWY_DASSERT(HasCandidates()); + min_ticks_.resize(candidates_.size(), ~uint64_t{0}); + } + + // Returns the current `TConfig` to measure. + const TConfig& NextConfig() const { + HWY_DASSERT(!Best() && HasCandidates()); + return candidates_[config_idx_]; + } + + // Returns the best ticks so far for this candidate. Negligible CPU time. + uint64_t NotifyTicks(uint64_t ticks) { + HWY_DASSERT(HasCandidates()); + HWY_DASSERT(!skipped_.Get(config_idx_)); + + best_ticks_ = HWY_MIN(best_ticks_, ticks); + min_ticks_[config_idx_] = HWY_MIN(min_ticks_[config_idx_], ticks); + // Best so far. Save because we update `config_idx_` below. + const size_t my_best_ticks = min_ticks_[config_idx_]; + const size_t my_idx = config_idx_; + + // Advance/wrap around to next non-skipped config. Do this first because it + // updates `rounds_complete_`. To decorrelate measurements, we do not + // immediately re-measure the same config. + for (;;) { + ++config_idx_; + if (HWY_UNLIKELY(config_idx_ == candidates_.size())) { + config_idx_ = 0; + ++rounds_complete_; + } + // Guaranteed to terminate because `best_ticks_` is never worse than any + // other, hence is not skipped. + if (!skipped_.Get(config_idx_)) break; + } + + // Disqualify from future `NextConfig` if the best of two measurements so + // far is sufficiently worse than `best_ticks_`. This tolerates some noise + // in the first or second measurement. + if (rounds_complete_ != 0 && my_best_ticks > 5 * best_ticks_ / 4) { + skipped_.Set(my_idx); + } + + // After sufficient rounds, choose the winner. + if (rounds_complete_ == 4) { + for (size_t i = 0; i < candidates_.size(); ++i) { + worst_min_ticks_ = HWY_MAX(worst_min_ticks_, min_ticks_[i]); + if (min_ticks_[i] == best_ticks_) { + // Causes `Best()` to be non-null, hence `MatMul` will no longer call + // `NextConfig` for this shape. + best_ = &candidates_[i]; + config_idx_ = i; // just in case callers want to know which index. + } + } + HWY_DASSERT(best_ != nullptr); // no min_ticks_ matches best_ticks_ + } + + return my_best_ticks; + } + + // Avoid printing the first two rounds, because those might be noisy and not + // yet skipped. + bool ShouldPrint() { return rounds_complete_ > 2; } + + // Only valid after Best() is non-null. Used to compute the autotuning gain. 
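A sketch of how a caller drives this state machine, mirroring the measure-and-notify flow in `MatMul`; `run_once` and `now_ticks` are hypothetical callables standing in for the kernel invocation and `hwy::timer`:

// Measures candidates until a winner emerges, then returns it. The real code
// performs one measurement per MatMul call rather than looping in place.
template <class TConfig, class RunOnce, class NowTicks>
const TConfig& AutotuneSketch(MMAutoTune<TConfig>& tuner,
                              std::vector<TConfig> candidates,
                              const RunOnce& run_once, const NowTicks& now_ticks) {
  if (!tuner.HasCandidates()) tuner.SetCandidates(std::move(candidates));
  while (!tuner.Best()) {
    const TConfig& cfg = tuner.NextConfig();
    const uint64_t t0 = now_ticks();
    run_once(cfg);
    tuner.NotifyTicks(now_ticks() - t0);  // prunes slow configs, picks the best
  }
  return *tuner.Best();
}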
+  uint64_t BestTicks() const { return best_ticks_; }
+  uint64_t WorstMinTicks() const { return worst_min_ticks_; }
+  uint64_t FirstConfigTicks() const { return min_ticks_[0]; }
+
+ private:
+  const TConfig* best_ = nullptr;
+  std::vector<TConfig> candidates_;
+  // Use Min because threads are pinned, so we only expect additive noise.
+  std::vector<uint64_t> min_ticks_;  // one per candidate
+  size_t config_idx_ = 0;  // [0, candidates_.size())
+  size_t rounds_complete_ = 0;
+  uint64_t best_ticks_ = ~uint64_t{0};
+  uint64_t worst_min_ticks_ = 0;
+  hwy::BitSet4096<> skipped_;
+};
+
+//------------------------------------------------------------------------------
+
+// Map of previously seen dimensions to index via linear search.
+class MMKeys {
+ public:
+  using Key = uint64_t;
+  // KeyFromDims will only return this if all dims are zero, which is invalid.
+  static constexpr Key kPadding = 0;
+
+  // Compresses the dimensions into a single Key for faster comparison.
+  static Key KeyFromDims(size_t M, size_t K, size_t N) {
+    HWY_DASSERT(M < (Key{1} << 16));  // batch sizes are smaller
+    HWY_DASSERT(K < (Key{1} << 24));
+    HWY_DASSERT(N < (Key{1} << 24));
+    const Key key = static_cast<Key>(M) | (static_cast<Key>(K) << 16) |
+                    (static_cast<Key>(N) << 40);
+    HWY_DASSERT(key != kPadding);
+    return key;
+  }
+
+  // We leave the search to callers so they can use dynamic-dispatched SIMD,
+  // which is not possible in this header.
+  hwy::Span<const Key> Keys() const {
+    return hwy::Span<const Key>(keys_.get(), num_unique_);
+  }
+
+  // Must only be called if not already present in `Keys()`.
+  void Append(Key key) {
+    // Dynamic allocation because the test checks many more dimensions than
+    // would be reasonable to pre-allocate. DIY for alignment and padding.
+    if (HWY_UNLIKELY(num_unique_ >= capacity_)) {
+      const size_t NU64 = Allocator::VectorBytes() / sizeof(Key);
+      // Start at one vector so the size is always a multiple of N.
+      if (HWY_UNLIKELY(capacity_ == 0)) {
+        capacity_ = hwy::DivCeil(NU64, 2);  // will be doubled below
+      }
+      capacity_ *= 2;
+      HWY_DASSERT(capacity_ >= num_unique_ + 1);
+      hwy::AlignedFreeUniquePtr<Key[]> new_keys =
+          hwy::AllocateAligned<Key>(capacity_);
+      hwy::CopyBytes(keys_.get(), new_keys.get(), num_unique_ * sizeof(Key));
+      // Pad for SIMD.
+      for (size_t i = num_unique_; i < hwy::RoundUpTo(num_unique_, NU64); ++i) {
+        new_keys[i] = kPadding;
+      }
+      keys_.swap(new_keys);
+    }
+    keys_[num_unique_++] = key;
+  }
+
+ private:
+  size_t capacity_ = 0;
+  size_t num_unique_ = 0;
+  hwy::AlignedFreeUniquePtr<Key[]> keys_;
+};
+
+// Per-MatMul-shape state.
+struct MMPerKey {
+  MMPerKey(size_t max_packages, size_t N, size_t nr, MMParallel& parallel)
+      : ranges_np(parallel.RangesOfNP(max_packages, N, nr)) {}
+
+  // Only profile if enabled and the main autotuner finished (the par_a
+  // autotuner is per-package and we want to avoid synchronization).
+  bool WantProfile() const { return PROFILER_ENABLED != 0 && autotune.Best(); }
+
+  const IndexRangePartition ranges_np;
+  MMAutoTune<MMConfig> autotune;
+  MMAutoTune<MMParA> autotune_par_a[MMParallel::kMaxPackages];
+};
+
+// Stores state shared across MatMul calls. Non-copyable.
+struct MatMulEnv {
+  explicit MatMulEnv(NestedPools& pools);
+
+  bool have_timer_stop = false;
+
+  // Enable binding: disabled in Gemma until tensors support it, enabled in
+  // bench_matmul.cc.
+  bool enable_bind = false;
+
+  // Whether `MMCandidates()` should print the set of parameters.
+  bool print_config = false;
+  // Whether to print each config's speed during autotuning.
+  bool print_measurement = false;
+  // Whether to print the best config immediately after autotuning finished.
+  bool print_best = false;
 
   MMParallel parallel;
+  MMStorage storage;
+  MMKeys keys;
+  std::vector<MMPerKey> per_key;
+};
 
-  // TODO: remove once no longer used.
-  NestedPools& Pools() const { return parallel.Pools(); }
-  hwy::ThreadPool& Pool() const { return parallel.Pool(); }
+// Arguments to MatMul() that are independent of the A/B type.
+// Reduces register pressure compared to individual values/references.
+struct MMArgs {
+  MMArgs(MatMulEnv& env, MMPerKey& per_key, double scale,
+         const float* HWY_RESTRICT add, const RowPtrD& partial,
+         const RowPtrF& C)
+      : env(&env),
+        per_key(&per_key),
+        scale(scale),
+        add(add),
+        partial(partial),
+        C(C) {}
+
+  MatMulEnv* env;
+  MMPerKey* per_key;
+
+  double scale;
+  const float* HWY_RESTRICT add;
+  // Same size as C, threads write at false-sharing-free granularity.
+  RowPtrD partial;
+  RowPtrF C;
+};
+
+// Wrapper over hwy::Zone that is only enabled when autotuning finished.
+#if PROFILER_ENABLED
+class MMZone {
+  using Zone = hwy::Zone;
+  static_assert(alignof(Zone) <= 8 && sizeof(Zone) <= 8);
+
+ public:
+  ~MMZone() {
+    if (used_) {
+      Zone* zone = reinterpret_cast<Zone*>(&data_);
+      zone->~Zone();
+    }
+  }
+
+  // `name` must be a string literal.
+  void MaybeEnter(const char* name, const MMArgs& args) {
+    if (args.per_key->WantProfile()) {
+      new (&data_) Zone(name);
+      used_ = true;
+    }
+  }
 
  private:
-  RowVectorBatch<float> buf_;
+  uint64_t data_ = 0;
+  bool used_ = false;
 };
+#else
+struct MMZone {
+  void MaybeEnter(const char*, const MMArgs&) {}
+};
+#endif  // PROFILER_ENABLED
 
 // Used for the A and B arguments of `MatMul`, which are always const.
 // Create via MakeConstMat. This differs from `RowPtr` in that it supports the
@@ -161,6 +742,29 @@ ConstMat<T> ConstMatFromWeights(const MatPtrT<T>& m, size_t ofs = 0) {
   return mat;
 }
 
+template <typename TB>
+void BindB(size_t N, const ConstMat<TB>& B, MMParallel& parallel) {
+  if (!Allocator::ShouldBind()) return;
+
+  const IndexRangePartition ranges_np =
+      parallel.RangesOfNP(MMParallel::kMaxPackages, N, kNR);
+  const size_t quantum = Allocator::QuantumBytes() / sizeof(TB);
+  for (size_t pkg_idx = 0; pkg_idx < ranges_np.NumTasks(); ++pkg_idx) {
+    const IndexRange& rows_b = ranges_np.Range(pkg_idx);
+    const size_t node = parallel.Node(pkg_idx);
+    uintptr_t begin =
+        reinterpret_cast<uintptr_t>(B.ptr + B.Row(rows_b.begin()));
+    uintptr_t end = begin + rows_b.Num() * B.Stride() * sizeof(TB);
+    // B is not yet guaranteed to have padded rows, so only bind the
+    // subset that is page-aligned.
+    begin = hwy::RoundUpTo(begin, quantum);
+    end = hwy::RoundDownTo(end, quantum);
+    if (HWY_LIKELY(begin != end)) {
+      Allocator::BindMemory(reinterpret_cast<void*>(begin), end - begin, node);
+    }
+  }
+}
+
 }  // namespace gcpp
 
 #endif  // THIRD_PARTY_GEMMA_CPP_OPS_MATMUL_H_
diff --git a/ops/matmul_test.cc b/ops/matmul_test.cc
index 3dc90b1..ad57508 100644
--- a/ops/matmul_test.cc
+++ b/ops/matmul_test.cc
@@ -243,17 +243,20 @@ template <typename TA, typename TB = TA>
 void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add,
                 MatMulEnv& env) {
   hwy::ThreadPool& pool = env.parallel.Pools().Pool();
-  fprintf(stderr, "TestMatMul %zu, %zu, %zu, add=%d, TA=%s, TB=%s\n", rows_ac,
+  fprintf(stderr, "TestMatMul %zu, K=%zu, %zu, add=%d, TA=%s, TB=%s\n", rows_ac,
           cols_a_rows_b, cols_bc, add, TypeName<TA>(), TypeName<TB>());
 
+  env.print_config = true;
+  env.print_best = true;
+
   const Extents2D A_extents(rows_ac, cols_a_rows_b);
   const Extents2D B_extents(cols_bc, cols_a_rows_b);  // already transposed
   const Extents2D C_extents(rows_ac, cols_bc);
 
   MatStoragePtr<TA> a = GenerateMat<TA>(A_extents, pool);
   MatStoragePtr<TB> b_trans = GenerateTransposedMat<TB>(B_extents, pool);
-  RowVectorBatch<float> c_slow_batch(C_extents);
-  RowVectorBatch<float> c_batch(C_extents);
+  RowVectorBatch<float> c_slow_batch = AllocateAlignedRows<float>(C_extents);
+  RowVectorBatch<float> c_batch = AllocateAlignedRows<float>(C_extents);
   HWY_ASSERT(a && b_trans);
 
   std::unique_ptr<MatStorageT<float>> add_storage;
@@ -270,8 +273,12 @@ void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add,
   const RowPtrF C = RowPtrFromBatch(c_batch);
 
   MatMulSlow(A, B, add_row, env, C_slow);
-  MatMul(A, B, add_row, env, C);
-  AssertClose(A, B, C_slow, C);
+  // A few reps to get coverage of the various autotuned code paths.
+  for (size_t rep = 0; rep < 16; ++rep) {
+    MMPerKey* per_key = MatMul(A, B, add_row, env, C);
+    AssertClose(A, B, C_slow, C);
+    if (per_key->autotune.Best()) break;
+  }
 }
 
 using F32 = float;
@@ -298,13 +305,12 @@ void TestTiny() {
     Tristate use_spinning = Tristate::kDefault;
     pools.MaybeStartSpinning(use_spinning);
-    Allocator::Init(pools.Topology());
+    Allocator::Init(pools.Topology(), /*enable_bind=*/true);
     MatMulEnv env(pools);
 
-    for (size_t M = 1; M <= 3 * kRegRows; ++M) {
-      for (size_t K = 64; K <= 128; K *= 2) {
-        for (size_t N = /*kRegRows*/ 16; N <= 64;
-             N += max_packages * kRegRows) {
+    for (size_t M = 1; M <= 12; ++M) {
+      for (size_t K = 1; K <= 64; K *= 2) {
+        for (size_t N = 4; N <= 64; N += max_packages * 4) {
           TestMatMul<F32, F32>(M, K, N, /*add=*/false, env);
         }
       }
@@ -323,7 +329,7 @@ void TestAllMatMul() {
   NestedPools pools(0);  // no limits
   Tristate use_spinning = Tristate::kDefault;
   pools.MaybeStartSpinning(use_spinning);
-  Allocator::Init(pools.Topology());
+  Allocator::Init(pools.Topology(), /*enable_bind=*/true);
   MatMulEnv env(pools);
 
   // Sizes seen in gemma_test 2B.
diff --git a/ops/matmul_unit_test.cc b/ops/matmul_unit_test.cc
deleted file mode 100644
index f8752b8..0000000
--- a/ops/matmul_unit_test.cc
+++ /dev/null
@@ -1,17 +0,0 @@
-// Copyright 2023 Google LLC
-// SPDX-License-Identifier: Apache-2.0
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//      http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// TODO: Tests of individual MatMul components.
-int main() { return 0; }
\ No newline at end of file
diff --git a/ops/ops_test.cc b/ops/ops_test.cc
index 8e57373..e6a71ef 100644
--- a/ops/ops_test.cc
+++ b/ops/ops_test.cc
@@ -35,6 +35,7 @@
 #include "gemma/common.h"
 #include "gemma/configs.h"
 #include "util/allocator.h"
+#include "util/app.h"
 #include "util/test_util.h"
 #include "hwy/base.h"
 #include "hwy/tests/hwy_gtest.h"
@@ -386,6 +387,9 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void ScalarRopeAndMulBy(
 }
 
 void TestRopeAndMulBy() {
+  NestedPools pools = CreatePools(AppArgs());
+  Allocator::Init(pools.Topology());
+
   ModelConfig config = ConfigFromModel(Model::GEMMA2_9B);
   int dim_qkv = config.layer_configs[0].qkv_dim;
   RowVectorBatch<float> x(Extents2D(1, dim_qkv));
diff --git a/util/allocator.cc b/util/allocator.cc
index f87ed50..c3db82f 100644
--- a/util/allocator.cc
+++ b/util/allocator.cc
@@ -34,6 +34,11 @@
 #endif
 
 #ifndef GEMMA_BIND  // allow override
+// OSes will generally do the right thing when threads allocate their own
+// working memory. However, matmul's B and C matrices are preferably sharded
+// across NUMA nodes. To simplify the matrix representation, we prefer a
+// single allocation. This requires page-level control over the memory layout,
+// which Linux provides via `move_pages`, but Windows does not.
 #if defined(GEMMA_LINUX_SYSCALL6) && !defined(__ANDROID_API__)
 #define GEMMA_BIND 1
 #else
@@ -93,7 +98,7 @@ size_t Allocator::L2Bytes() { return l2_bytes_; }
 size_t Allocator::L3Bytes() { return l3_bytes_; }
 bool Allocator::ShouldBind() { return should_bind_; }
 
-void Allocator::Init(const BoundedTopology& topology) {
+void Allocator::Init(const BoundedTopology& topology, bool enable_bind) {
   line_bytes_ = DetectLineBytes();
   vector_bytes_ = hwy::VectorBytes();
   step_bytes_ = HWY_MAX(line_bytes_, vector_bytes_);
@@ -122,13 +127,19 @@ void Allocator::Init(const BoundedTopology& topology) {
   const size_t page_bytes = DetectPageSize();
   if ((page_bytes != 0 && page_bytes <= 16 * 1024) &&
       topology.NumNodes() > 1 && topology.NumPackages() > 1) {
-    // Ensure pages meet the alignment requirements of `AllocBytes`.
-    HWY_ASSERT(page_bytes >= quantum_bytes_);
-    quantum_bytes_ = page_bytes;
-    // Ensure MaxQuantumBytes() is an upper bound.
-    HWY_ASSERT(MaxQuantumBytes() >= quantum_bytes_);
-    quantum_bytes_ = HWY_MIN(quantum_bytes_, MaxQuantumBytes());
-    should_bind_ = true;
+    if (enable_bind) {
+      // Ensure pages meet the alignment requirements of `AllocBytes`.
+      HWY_ASSERT(page_bytes >= quantum_bytes_);
+      quantum_bytes_ = page_bytes;
+      // Ensure MaxQuantumBytes() is an upper bound.
+      HWY_ASSERT(MaxQuantumBytes() >= quantum_bytes_);
+      quantum_bytes_ = HWY_MIN(quantum_bytes_, MaxQuantumBytes());
+      should_bind_ = true;
+    } else {
+      HWY_WARN(
+          "Multiple sockets but binding disabled. This reduces speed; "
+          "set or remove enable_bind to avoid this warning.");
+    }
   }
 }
 
diff --git a/util/allocator.h b/util/allocator.h
index 4007b22..55c26af 100644
--- a/util/allocator.h
+++ b/util/allocator.h
@@ -16,6 +16,8 @@
 #ifndef THIRD_PARTY_GEMMA_CPP_UTIL_ALLOCATOR_H_
 #define THIRD_PARTY_GEMMA_CPP_UTIL_ALLOCATOR_H_
 
+// Allocator with support for sharding tensors across NUMA nodes.
+
 #include <stddef.h>
 #include <stdint.h>
 
@@ -65,7 +67,8 @@ class Allocator {
  public:
   // Must be called at least once before any other function. Not thread-safe,
   // hence only call this from the main thread.
-  static void Init(const BoundedTopology& topology);
+  // TODO: remove enable_bind once Gemma tensors support binding.
+  static void Init(const BoundedTopology& topology, bool enable_bind = false);
 
   // Bytes per cache line, or a reasonable guess if unknown. Used to choose
   // ranges such that there will be no false sharing.
@@ -80,8 +83,10 @@ class Allocator {
   static constexpr size_t MaxQuantumBytes() { return 4096; }
   static size_t QuantumSteps();  // = QuantumBytes() / StepBytes()
 
+  // L1 and L2 are typically per core.
   static size_t L1Bytes();
   static size_t L2Bytes();
+  // Clusters often share an L3. We return the total size per package.
   static size_t L3Bytes();
 
   // Returns pointer aligned to `QuantumBytes()`.
@@ -119,6 +124,19 @@ class Allocator {
   static PtrAndDeleter AllocBytes(size_t bytes);
 };
 
+// Value of `stride` to pass to `RowVectorBatch` to enable the "cyclic offsets"
+// optimization. If `Allocator::ShouldBind()`, `Allocator::QuantumBytes()` is
+// typically 4KiB. To avoid remote accesses, we would thus pad each row to that,
+// which results in 4K aliasing and/or cache conflict misses. `RowPtr` is able
+// to prevent that by pulling rows forward by a cyclic offset, which is still a
+// multiple of the cache line size. This requires an additional
+// `Allocator::QuantumBytes()` of padding after also rounding up to that.
+template <typename T>
+constexpr size_t StrideForCyclicOffsets(size_t cols) {
+  const size_t quantum = Allocator::MaxQuantumBytes() / sizeof(T);
+  return hwy::RoundUpTo(cols, quantum) + quantum;
+}
+
 // Owns dynamically-allocated aligned memory for a batch of row vectors.
 // This can be seen as a (batch_size x cols) matrix. Unlike `RowPtr`, this owns
 // the memory.
@@ -130,6 +148,7 @@ class RowVectorBatch {
   // Main ctor, called from Activations::Allocate. If `stride` = 0, the default,
   // we default to tightly packed rows (`stride = cols`).
   // WARNING: not all call sites support `stride` != cols.
+  // TODO: once they do, remove stride and behave like AllocateAlignedRows here.
   RowVectorBatch(Extents2D extents, size_t stride = 0) : extents_(extents) {
     if (stride == 0) {
       stride_ = extents_.cols;
@@ -137,7 +156,10 @@ class RowVectorBatch {
       HWY_ASSERT(stride >= extents_.cols);
       stride_ = stride;
     }
-    mem_ = Allocator::Alloc<T>(extents_.rows * stride_);
+    // Allow binding the entire matrix.
+    const size_t padded = hwy::RoundUpTo(extents_.rows * stride_,
+                                         Allocator::QuantumBytes() / sizeof(T));
+    mem_ = Allocator::Alloc<T>(padded);
   }
 
   // Move-only
@@ -186,6 +208,11 @@ static HWY_INLINE size_t RoundUpToOddLines(size_t num, size_t line_bytes) {
   return padded_num;
 }
 
+template <typename T>
+RowVectorBatch<T> AllocateAlignedRows(Extents2D extents) {
+  return RowVectorBatch<T>(extents, StrideForCyclicOffsets<T>(extents.cols));
+}
+
 // Lightweight version of `MatPtr` used for the C argument of `MatMul`, because
 // it is always float and does not support compressed T, but does support an
 // arbitrary stride >= cols.
@@ -202,7 +229,19 @@ class RowPtr {
         row_mask_(Allocator::QuantumSteps() - 1) {
     HWY_DASSERT(stride >= cols);
     HWY_DASSERT(row_mask_ != ~size_t{0});
-    row_mask_ = 0;  // TODO: remove
+    if constexpr (HWY_IS_DEBUG_BUILD) {
+      if (stride < StrideForCyclicOffsets<T>(cols)) {
+        static bool once;
+        if (!once) {
+          once = true;
+          HWY_WARN(
+              "Check why RowPtr stride=%zu < StrideForCyclicOffsets(cols=%zu), "
+              "T=%zu; this forces us to disable cyclic offsets.",
+              stride, cols, sizeof(T));
+        }
+        row_mask_ = 0;
+      }
+    }
   }
 
   RowPtr(T* HWY_RESTRICT row0, size_t cols) : RowPtr(row0, cols, cols) {}
diff --git a/util/app.h b/util/app.h
index d759467..49a75b5 100644
--- a/util/app.h
+++ b/util/app.h
@@ -295,6 +295,19 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
     runtime_config.max_generated_tokens = max_generated_tokens;
     runtime_config.prefill_tbatch_size = prefill_tbatch_size;
    runtime_config.decode_qbatch_size = decode_qbatch_size;
+    if (prefill_tbatch_size > MMStorage::kMaxM) {
+      HWY_ABORT(
+          "prefill_tbatch_size %zu > kMaxM %zu: specify a smaller value, "
+          "or increase the constant in MMStorage.\n",
+          prefill_tbatch_size, MMStorage::kMaxM);
+    }
+    if (decode_qbatch_size > MMStorage::kMaxM) {
+      HWY_ABORT(
+          "decode_qbatch_size %zu > kMaxM %zu: specify a smaller value, "
+          "or increase the constant in MMStorage.\n",
+          decode_qbatch_size, MMStorage::kMaxM);
+    }
+
     runtime_config.temperature = temperature;
     runtime_config.top_k = top_k;
   }
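
Usage sketch for the MMAutoTune<> state machine added in ops/matmul.h above (not part of the patch): `MakeCandidates` and `MeasureTicks` are hypothetical stand-ins for `MMCandidates()` and a timed MatMul call with the chosen config; the real driver lives inside MatMul itself, keyed by MMKeys/MMPerKey.

  // Sketch only: one tuner per MatMul shape, as in MMPerKey::autotune.
  MMAutoTune<MMConfig> tuner;
  if (!tuner.HasCandidates()) {
    tuner.SetCandidates(MakeCandidates());     // e.g. via MMCandidates()
  }
  while (!tuner.Best()) {
    const MMConfig& cfg = tuner.NextConfig();  // candidate to measure next
    const uint64_t ticks = MeasureTicks(cfg);  // run MatMul with cfg and time it
    tuner.NotifyTicks(ticks);  // prunes slow candidates; picks winner after 4 rounds
  }
  const MMConfig& best = *tuner.Best();        // do not store the pointer long-term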
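
The padding arithmetic behind StrideForCyclicOffsets in util/allocator.h, assuming Allocator::MaxQuantumBytes() == 4096 as defined above:

  // Sketch only: StrideForCyclicOffsets<float>, quantum = 4096 / sizeof(float) = 1024.
  // cols = 2048 (already a multiple of 1024): 2048 + 1024 = 3072 floats per row.
  // cols = 2050: RoundUpTo(2050, 1024) + 1024 = 3072 + 1024 = 4096 floats per row.
  // Each row therefore has at least one quantum of slack, which RowPtr can use
  // to pull rows forward by a cache-line-aligned cyclic offset.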