Use vectorized TopK via Highway VQSelect
PiperOrigin-RevId: 726830079
apoorvreddy authored and copybara-github committed Feb 18, 2025
1 parent 0e5b59d commit aaba127
Showing 4 changed files with 94 additions and 58 deletions.
1 change: 1 addition & 0 deletions BUILD.bazel
@@ -108,6 +108,7 @@ cc_library(
"@highway//:matvec",
"@highway//:profiler",
"@highway//:thread_pool",
"@highway//hwy/contrib/sort:vqsort",
],
)

137 changes: 79 additions & 58 deletions ops/ops-inl.h
@@ -22,14 +22,16 @@
#include <stdio.h>

#include <cmath>
#include <limits>
#include <cstdint>
#include <random>
#include <type_traits> // std::enable_if_t
#include <vector>

#include "compression/compress.h"
#include "util/basics.h" // TokenAndProb
#include "hwy/base.h"
#include "hwy/contrib/sort/order.h"
#include "hwy/contrib/sort/vqsort.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/detect_targets.h"
#include "hwy/profiler.h"
@@ -54,6 +56,35 @@ namespace gcpp {
namespace HWY_NAMESPACE {
namespace hn = hwy::HWY_NAMESPACE;

HWY_INLINE double PackTokenAndProb(int32_t token, float prob) {
// casting prob from float to double just makes some changes to the
// exponent bias and pads zeros in the mantissa.
double packed = static_cast<double>(prob);
int64_t packed_int64;
hwy::CopySameSize(&packed, &packed_int64);
// stuff the token into the lower 32 bits of packed_int64. (it is an int32_t
// anyway)
packed_int64 &= 0xFFFFFFFF00000000;
packed_int64 |= token;
// copy bytes back into packed.
hwy::CopySameSize(&packed_int64, &packed);
return packed;
}

HWY_INLINE TokenAndProb UnpackTokenAndProb(double packed) {
TokenAndProb tp;

int64_t packed_int64;
hwy::CopySameSize(&packed, &packed_int64);
tp.token = static_cast<int>(packed_int64 & 0xFFFFFFFFULL);

// clear the lower 32 bits of packed_int64 before copying back into packed.
packed_int64 &= 0xFFFFFFFF00000000ULL;
hwy::CopySameSize(&packed_int64, &packed);
tp.prob = static_cast<float>(packed);
return tp;
}
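
These two helpers encode a (token, probability) pair in a single 64-bit double so that sorting doubles sorts the pairs by probability. A minimal standalone sketch of the same trick follows, using std::memcpy in place of hwy::CopySameSize; the names Pack and Unpack are illustrative, not part of this CL. Because a float widened to double leaves the low 29 mantissa bits zero, writing a token id into the low 32 bits perturbs the probability by at most a few float ulps, so comparing two packed doubles is, to within that truncation, comparing the probabilities, with the token id breaking ties.

#include <cstdint>
#include <cstdio>
#include <cstring>

double Pack(int32_t token, float prob) {
  double packed = static_cast<double>(prob);  // widen; low mantissa bits are (almost all) zero
  uint64_t bits;
  std::memcpy(&bits, &packed, sizeof(bits));  // portable stand-in for hwy::CopySameSize
  bits = (bits & 0xFFFFFFFF00000000ULL) | static_cast<uint32_t>(token);
  std::memcpy(&packed, &bits, sizeof(packed));
  return packed;
}

void Unpack(double packed, int32_t* token, float* prob) {
  uint64_t bits;
  std::memcpy(&bits, &packed, sizeof(bits));
  *token = static_cast<int32_t>(bits & 0xFFFFFFFFULL);
  bits &= 0xFFFFFFFF00000000ULL;  // drop the token before reading the probability back
  std::memcpy(&packed, &bits, sizeof(packed));
  *prob = static_cast<float>(packed);
}

int main() {
  const double a = Pack(10, 0.96f);
  const double b = Pack(1000000000, 0.87f);
  std::printf("a > b: %d\n", a > b);  // prints 1: the larger probability wins despite the larger token in b
  int32_t token;
  float prob;
  Unpack(a, &token, &prob);
  std::printf("token=%d prob=%.6f\n", token, prob);  // token=10 prob=0.960000
  return 0;
}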

template <typename To, typename From>
HWY_INLINE constexpr std::enable_if_t<
std::is_arithmetic_v<To> && std::is_arithmetic_v<From>, To>
@@ -705,37 +736,44 @@ HWY_INLINE HWY_MAYBE_UNUSED std::discrete_distribution<int> create_distribution(
}

template <typename TAcceptToken>
HWY_NOINLINE HWY_MAYBE_UNUSED int SampleTopK(
const float* HWY_RESTRICT probabilities, size_t k, size_t vocab_size,
std::mt19937& gen, float temperature, TAcceptToken& accept_token) {
HWY_NOINLINE HWY_MAYBE_UNUSED std::vector<TokenAndProb> TopK(
const float* HWY_RESTRICT probabilities, size_t vocab_size, size_t k,
TAcceptToken& accept_token) {
HWY_ASSERT(k != 0);
HWY_ASSERT(k <= vocab_size);
// TODO: Optimize, potentially using new VQSort PartialSort.
// Sorted from highest [0], to lowest [k-1]
std::vector<float> top_k(k, -std::numeric_limits<float>::infinity());
std::vector<int> indices(k);
size_t num_accepted = 0;
for (size_t i = 0; i < vocab_size; ++i) {
if (probabilities[i] < top_k[k - 1]) continue;
bool accepted =
!accept_token || accept_token(StaticCast<int>(i), probabilities[i]);
if (!accepted) continue;
num_accepted++;
for (size_t j = 0; j < k; ++j) {
if (probabilities[i] > top_k[j]) {
// shift elements by 1, insert the new value, move on to next value
for (size_t idx = k - 1; idx > j; --idx) {
top_k[idx] = top_k[idx - 1];
indices[idx] = indices[idx - 1];
}
top_k[j] = probabilities[i];
indices[j] = StaticCast<int>(i);
break;
}
std::vector<double> packed_token_probs;
for (int32_t i = 0; i < vocab_size; ++i) {
if (accept_token && !accept_token(StaticCast<int>(i), probabilities[i])) {
continue;
}
packed_token_probs.push_back(PackTokenAndProb(i, probabilities[i]));
}

hwy::VQSelect(packed_token_probs.data(), packed_token_probs.size(), k,
hwy::SortDescending());
hwy::VQSort(packed_token_probs.data(), k, hwy::SortDescending());

std::vector<TokenAndProb> token_probs;
token_probs.reserve(k);
for (int32_t i = 0; i < k; ++i) {
token_probs.push_back(UnpackTokenAndProb(packed_token_probs[i]));
}
HWY_ASSERT(k <= num_accepted);
return indices[create_distribution(top_k, temperature)(gen)];
return token_probs;
}
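
TopK leans on two Highway sorting-network entry points: VQSelect partitions the k largest keys to the front of the array without fully ordering them, and VQSort then orders only those k, so the whole vocabulary is never fully sorted. A standalone sketch of the select-then-sort pattern on plain doubles, assuming a Highway checkout recent enough to export hwy::VQSelect (the same call used above):

#include <cstddef>
#include <cstdio>
#include <vector>

#include "hwy/contrib/sort/order.h"
#include "hwy/contrib/sort/vqsort.h"

int main() {
  std::vector<double> keys = {0.1, 0.9, 0.3, 0.7, 0.5, 0.2, 0.8, 0.4};
  const size_t k = 3;
  hwy::VQSelect(keys.data(), keys.size(), k, hwy::SortDescending());  // keys[0..2] now hold the 3 largest, unordered
  hwy::VQSort(keys.data(), k, hwy::SortDescending());                 // order just those 3
  for (size_t i = 0; i < k; ++i) std::printf("%g ", keys[i]);         // 0.9 0.8 0.7
  std::printf("\n");
  return 0;
}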

template <typename TAcceptToken>
HWY_NOINLINE HWY_MAYBE_UNUSED int SampleTopK(
const float* HWY_RESTRICT probabilities, size_t k, size_t vocab_size,
std::mt19937& gen, float temperature, TAcceptToken& accept_token) {
std::vector<TokenAndProb> token_probs =
TopK(probabilities, vocab_size, k, accept_token);
std::vector<int> topk_indices(k);
std::vector<float> topk_probs(k);
for (int i = 0; i < k; ++i) {
topk_indices[i] = token_probs[i].token;
topk_probs[i] = token_probs[i].prob;
}
return topk_indices[create_distribution(topk_probs, temperature)(gen)];
}
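
The rewritten SampleTopK hands the k surviving probabilities to create_distribution (defined outside this hunk) and draws one index. A standalone sketch of just that final step with std::discrete_distribution, assuming temperature 1 so the probabilities are used as weights unchanged:

#include <cstdio>
#include <random>
#include <vector>

int main() {
  std::vector<float> topk_probs = {0.5f, 0.3f, 0.2f};  // highest first, as returned by TopK
  std::mt19937 gen(42);
  std::discrete_distribution<int> dist(topk_probs.begin(), topk_probs.end());
  const int slot = dist(gen);  // 0 with prob 0.5, 1 with prob 0.3, 2 with prob 0.2
  std::printf("sampled top-k slot %d\n", slot);
  return 0;
}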

template <typename TAcceptToken>
@@ -745,40 +783,23 @@ HWY_NOINLINE HWY_MAYBE_UNUSED TokenAndProb FusedSoftmaxAndSampleTopK(
// Softmax and sample top-K is equivalent to taking the top-K logits and
// sampling from the softmax of the top-K logits. The latter is faster as it
// avoids computing the softmax of all logits.
HWY_ASSERT(k != 0);
HWY_ASSERT(k <= vocab_size);

std::vector<float> top_k(k, -std::numeric_limits<float>::infinity());
std::vector<int> indices(k);
size_t num_accepted = 0;
for (size_t i = 0; i < vocab_size; ++i) {
if (logits[i] < top_k[k - 1]) continue;
bool accepted =
!accept_token || accept_token(StaticCast<int>(i), logits[i]);
if (!accepted) continue;
num_accepted++;
for (size_t j = 0; j < k; ++j) {
if (logits[i] > top_k[j]) {
// shift elements by 1, insert the new value, move on to next value
for (size_t idx = k - 1; idx > j; --idx) {
top_k[idx] = top_k[idx - 1];
indices[idx] = indices[idx - 1];
}
top_k[j] = logits[i];
indices[j] = StaticCast<int>(i);
break;
}
}
std::vector<TokenAndProb> token_logits =
TopK(logits, vocab_size, k, accept_token);
std::vector<int> topk_indices(k);
std::vector<float> topk_logits(k);
for (int i = 0; i < token_logits.size(); ++i) {
topk_indices[i] = token_logits[i].token;
topk_logits[i] = token_logits[i].prob;
}

size_t mask = k <= num_accepted ? k : num_accepted;
Softmax(top_k.data(), mask, temperature);
auto distribution = std::discrete_distribution<int>(std::begin(top_k),
std::begin(top_k) + mask);
size_t mask = token_logits.size();
Softmax(topk_logits.data(), mask, temperature);
auto distribution = std::discrete_distribution<int>(
std::begin(topk_logits), std::begin(topk_logits) + mask);
int topk_sampled_index = distribution(gen);
int sampled_index = indices[topk_sampled_index];
int sampled_index = topk_indices[topk_sampled_index];
return TokenAndProb{.token = sampled_index,
.prob = top_k[topk_sampled_index]};
.prob = topk_logits[topk_sampled_index]};
}
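
The comment at the top of FusedSoftmaxAndSampleTopK is the key observation: softmax is monotonic, so the k largest logits are the k largest probabilities, and softmax over just those k logits equals the renormalized top-k probabilities. A standalone sketch of that reduced computation, with the temperature-scaled softmax written out explicitly (gcpp's Softmax(data, n, temperature) is assumed to do the equivalent):

#include <cmath>
#include <cstdio>
#include <random>
#include <vector>

int main() {
  std::vector<float> topk_logits = {2.0f, 1.0f, 0.5f};  // highest first, as returned by TopK
  const float temperature = 0.8f;

  // softmax(logit / T), subtracting the max logit for numerical stability.
  const float max_logit = topk_logits[0];
  std::vector<float> probs(topk_logits.size());
  float sum = 0.0f;
  for (size_t i = 0; i < topk_logits.size(); ++i) {
    probs[i] = std::exp((topk_logits[i] - max_logit) / temperature);
    sum += probs[i];
  }
  for (float& p : probs) p /= sum;

  std::mt19937 gen(123);
  std::discrete_distribution<int> dist(probs.begin(), probs.end());
  std::printf("sampled slot %d (p[0]=%.3f)\n", dist(gen), probs[0]);
  return 0;
}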

// NOLINTNEXTLINE(google-readability-namespace-comments)
12 changes: 12 additions & 0 deletions ops/ops_test.cc
@@ -600,6 +600,17 @@ void TestSampleTopK() {
}
}

void TestPackTokenAndProb() {
double packed1 = PackTokenAndProb(10, 0.96f);
TokenAndProb unpacked1 = UnpackTokenAndProb(packed1);
EXPECT_EQ(unpacked1.token, 10);
EXPECT_NEAR(unpacked1.prob, 0.96f, 1e-6);

double packed2 = PackTokenAndProb(1000000000, 0.87f);

EXPECT_LT(packed2, packed1);
}
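
EXPECT_LT holds because the probability occupies the high-order bits of the packed double: 0.87 packs below 0.96 no matter how large the token id is. A hypothetical companion test (not part of this CL) could cover the tie case, where equal probabilities fall back to comparing token ids:

void TestPackTokenAndProbTieBreak() {
  double low_token = PackTokenAndProb(5, 0.5f);
  double high_token = PackTokenAndProb(6, 0.5f);
  // Same probability, so the low 32 bits (the token) decide the order.
  EXPECT_LT(low_token, high_token);
  EXPECT_EQ(UnpackTokenAndProb(high_token).token, 6);
}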

// NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE
} // namespace gcpp
@@ -621,6 +632,7 @@ HWY_EXPORT_AND_TEST_P(OpsTest, TestAllRMSNorm);
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllLayerNorm);
HWY_EXPORT_AND_TEST_P(OpsTest, TestLayerNormSimple);
HWY_EXPORT_AND_TEST_P(OpsTest, TestSampleTopK);
HWY_EXPORT_AND_TEST_P(OpsTest, TestPackTokenAndProb);
HWY_AFTER_TEST();

} // namespace gcpp
2 changes: 2 additions & 0 deletions util/basics.h
@@ -57,10 +57,12 @@ static inline void MaybeCheckInitialized(const void* ptr, size_t size) {
}

// Shared between gemma.h and ops-inl.h.
#pragma pack(push, 1)
struct TokenAndProb {
int token;
float prob;
};
#pragma pack(pop)
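
The pragma pins TokenAndProb to exactly 8 bytes with no padding. A hypothetical compile-time check (not in this CL) that would document that size next to the struct, on the assumption that matching the 64-bit packed double in ops-inl.h is the motivation:

static_assert(sizeof(TokenAndProb) == 8, "token + prob must fit in 64 bits");
static_assert(sizeof(TokenAndProb) == sizeof(double),
              "mirrors the packed double used for sorting in ops-inl.h");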

// Entire size of a 2D array.
struct Extents2D {
