diff --git a/BUILD.bazel b/BUILD.bazel
index 0cf1801..2c816a5 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -108,6 +108,7 @@ cc_library(
         "@highway//:matvec",
        "@highway//:profiler",
        "@highway//:thread_pool",
+        "@highway//hwy/contrib/sort:vqsort",
    ],
)

diff --git a/ops/ops-inl.h b/ops/ops-inl.h
index aad636b..a207381 100644
--- a/ops/ops-inl.h
+++ b/ops/ops-inl.h
@@ -22,7 +22,7 @@
 #include <array>
 #include <cmath>
-#include <limits>
+#include <cstdint>
 #include <random>
 #include <type_traits>  // std::enable_if_t
 #include <vector>
@@ -30,6 +30,8 @@
 #include "compression/compress.h"
 #include "util/basics.h"  // TokenAndProb
 #include "hwy/base.h"
+#include "hwy/contrib/sort/order.h"
+#include "hwy/contrib/sort/vqsort.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
 #include "hwy/detect_targets.h"
 #include "hwy/profiler.h"
@@ -54,6 +56,35 @@ namespace gcpp {
 namespace HWY_NAMESPACE {
 namespace hn = hwy::HWY_NAMESPACE;
 
+HWY_INLINE double PackTokenAndProb(int32_t token, float prob) {
+  // casting prob from float to double just makes some changes to the
+  // exponent bias and pads zeros in the mantissa.
+  double packed = static_cast<double>(prob);
+  int64_t packed_int64;
+  hwy::CopySameSize(&packed, &packed_int64);
+  // stuff the token into the lower 32 bits of packed_int64. (it is an int32_t
+  // anyway)
+  packed_int64 &= 0xFFFFFFFF00000000;
+  packed_int64 |= token;
+  // copy bytes back into packed.
+  hwy::CopySameSize(&packed_int64, &packed);
+  return packed;
+}
+
+HWY_INLINE TokenAndProb UnpackTokenAndProb(double packed) {
+  TokenAndProb tp;
+
+  int64_t packed_int64;
+  hwy::CopySameSize(&packed, &packed_int64);
+  tp.token = static_cast<int>(packed_int64 & 0xFFFFFFFFULL);
+
+  // clear the lower 32 bits of packed_int64 before copying back into packed.
+  packed_int64 &= 0xFFFFFFFF00000000ULL;
+  hwy::CopySameSize(&packed_int64, &packed);
+  tp.prob = static_cast<float>(packed);
+  return tp;
+}
+
 template <typename To, typename From>
 HWY_INLINE constexpr std::enable_if_t<
     std::is_arithmetic_v<To> && std::is_arithmetic_v<From>, To>
@@ -705,37 +736,44 @@ HWY_INLINE HWY_MAYBE_UNUSED std::discrete_distribution<int> create_distribution(
 }
 
 template <typename TAcceptToken>
-HWY_NOINLINE HWY_MAYBE_UNUSED int SampleTopK(
-    const float* HWY_RESTRICT probabilities, size_t k, size_t vocab_size,
-    std::mt19937& gen, float temperature, TAcceptToken& accept_token) {
+HWY_NOINLINE HWY_MAYBE_UNUSED std::vector<TokenAndProb> TopK(
+    const float* HWY_RESTRICT probabilities, size_t vocab_size, size_t k,
+    TAcceptToken& accept_token) {
   HWY_ASSERT(k != 0);
   HWY_ASSERT(k <= vocab_size);
-  // TODO: Optimize, potentially using new VQSort PartialSort.
-  // Sorted from highest [0], to lowest [k-1]
-  std::vector<float> top_k(k, -std::numeric_limits<float>::infinity());
-  std::vector<int> indices(k);
-  size_t num_accepted = 0;
-  for (size_t i = 0; i < vocab_size; ++i) {
-    if (probabilities[i] < top_k[k - 1]) continue;
-    bool accepted =
-        !accept_token || accept_token(StaticCast<int>(i), probabilities[i]);
-    if (!accepted) continue;
-    num_accepted++;
-    for (size_t j = 0; j < k; ++j) {
-      if (probabilities[i] > top_k[j]) {
-        // shift elements by 1, insert the new value, move on to next value
-        for (size_t idx = k - 1; idx > j; --idx) {
-          top_k[idx] = top_k[idx - 1];
-          indices[idx] = indices[idx - 1];
-        }
-        top_k[j] = probabilities[i];
-        indices[j] = StaticCast<int>(i);
-        break;
-      }
+  std::vector<double> packed_token_probs;
+  for (int32_t i = 0; i < vocab_size; ++i) {
+    if (accept_token && !accept_token(StaticCast<int>(i), probabilities[i])) {
+      continue;
     }
+    packed_token_probs.push_back(PackTokenAndProb(i, probabilities[i]));
+  }
+
+  hwy::VQSelect(packed_token_probs.data(), packed_token_probs.size(), k,
+                hwy::SortDescending());
+  hwy::VQSort(packed_token_probs.data(), k, hwy::SortDescending());
+
+  std::vector<TokenAndProb> token_probs;
+  token_probs.reserve(k);
+  for (int32_t i = 0; i < k; ++i) {
+    token_probs.push_back(UnpackTokenAndProb(packed_token_probs[i]));
   }
-  HWY_ASSERT(k <= num_accepted);
-  return indices[create_distribution(top_k, temperature)(gen)];
+  return token_probs;
+}
+
+template <typename TAcceptToken>
+HWY_NOINLINE HWY_MAYBE_UNUSED int SampleTopK(
+    const float* HWY_RESTRICT probabilities, size_t k, size_t vocab_size,
+    std::mt19937& gen, float temperature, TAcceptToken& accept_token) {
+  std::vector<TokenAndProb> token_probs =
+      TopK(probabilities, vocab_size, k, accept_token);
+  std::vector<int> topk_indices(k);
+  std::vector<float> topk_probs(k);
+  for (int i = 0; i < k; ++i) {
+    topk_indices[i] = token_probs[i].token;
+    topk_probs[i] = token_probs[i].prob;
+  }
+  return topk_indices[create_distribution(topk_probs, temperature)(gen)];
 }
 
 template <typename TAcceptToken>
@@ -745,40 +783,23 @@ HWY_NOINLINE HWY_MAYBE_UNUSED TokenAndProb FusedSoftmaxAndSampleTopK(
   // Softmax and sample top-K is equivalent to taking the top-K logits and
   // sampling from the softmax of the top-K logits. The latter is faster as it
   // avoids computing the softmax of all logits.
-  HWY_ASSERT(k != 0);
-  HWY_ASSERT(k <= vocab_size);
-
-  std::vector<float> top_k(k, -std::numeric_limits<float>::infinity());
-  std::vector<int> indices(k);
-  size_t num_accepted = 0;
-  for (size_t i = 0; i < vocab_size; ++i) {
-    if (logits[i] < top_k[k - 1]) continue;
-    bool accepted =
-        !accept_token || accept_token(StaticCast<int>(i), logits[i]);
-    if (!accepted) continue;
-    num_accepted++;
-    for (size_t j = 0; j < k; ++j) {
-      if (logits[i] > top_k[j]) {
-        // shift elements by 1, insert the new value, move on to next value
-        for (size_t idx = k - 1; idx > j; --idx) {
-          top_k[idx] = top_k[idx - 1];
-          indices[idx] = indices[idx - 1];
-        }
-        top_k[j] = logits[i];
-        indices[j] = StaticCast<int>(i);
-        break;
-      }
-    }
+  std::vector<TokenAndProb> token_logits =
+      TopK(logits, vocab_size, k, accept_token);
+  std::vector<int> topk_indices(k);
+  std::vector<float> topk_logits(k);
+  for (int i = 0; i < token_logits.size(); ++i) {
+    topk_indices[i] = token_logits[i].token;
+    topk_logits[i] = token_logits[i].prob;
   }
-  size_t mask = k <= num_accepted ? k : num_accepted;
-  Softmax(top_k.data(), mask, temperature);
-  auto distribution = std::discrete_distribution<int>(std::begin(top_k),
-                                                      std::begin(top_k) + mask);
+  size_t mask = token_logits.size();
+  Softmax(topk_logits.data(), mask, temperature);
+  auto distribution = std::discrete_distribution<int>(
+      std::begin(topk_logits), std::begin(topk_logits) + mask);
   int topk_sampled_index = distribution(gen);
-  int sampled_index = indices[topk_sampled_index];
+  int sampled_index = topk_indices[topk_sampled_index];
   return TokenAndProb{.token = sampled_index,
-                      .prob = top_k[topk_sampled_index]};
+                      .prob = topk_logits[topk_sampled_index]};
 }
 
 // NOLINTNEXTLINE(google-readability-namespace-comments)
diff --git a/ops/ops_test.cc b/ops/ops_test.cc
index 8e57373..b795a44 100644
--- a/ops/ops_test.cc
+++ b/ops/ops_test.cc
@@ -600,6 +600,17 @@ void TestSampleTopK() {
   }
 }
 
+void TestPackTokenAndProb() {
+  double packed1 = PackTokenAndProb(10, 0.96f);
+  TokenAndProb unpacked1 = UnpackTokenAndProb(packed1);
+  EXPECT_EQ(unpacked1.token, 10);
+  EXPECT_NEAR(unpacked1.prob, 0.96f, 1e-6);
+
+  double packed2 = PackTokenAndProb(1000000000, 0.87f);
+
+  EXPECT_LT(packed2, packed1);
+}
+
 // NOLINTNEXTLINE(google-readability-namespace-comments)
 }  // namespace HWY_NAMESPACE
 }  // namespace gcpp
@@ -621,6 +632,7 @@ HWY_EXPORT_AND_TEST_P(OpsTest, TestAllRMSNorm);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestAllLayerNorm);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestLayerNormSimple);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestSampleTopK);
+HWY_EXPORT_AND_TEST_P(OpsTest, TestPackTokenAndProb);
 HWY_AFTER_TEST();
 
 }  // namespace gcpp
diff --git a/util/basics.h b/util/basics.h
index c296934..b8f2735 100644
--- a/util/basics.h
+++ b/util/basics.h
@@ -57,10 +57,12 @@ static inline void MaybeCheckInitialized(const void* ptr, size_t size) {
 }
 
 // Shared between gemma.h and ops-inl.h.
+#pragma pack(push, 1)
 struct TokenAndProb {
   int token;
   float prob;
 };
+#pragma pack(pop)
 
 // Entire size of a 2D array.
 struct Extents2D {