Skip to content

Commit

Permalink
Split loops to work around loop vectorizer weakness (#3406)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #3406

X-link: facebookresearch/FBGEMM#494

I noticed the LLVM loop vectorizer making poor decisions for loops whose trip counts are not a multiple of the preferred vector size.

This can be worked around by manually splitting each such loop in two: a first loop whose trip count is a multiple of the preferred vector size, followed by a separate loop that processes the remainder.

Reviewed By: helloguo

Differential Revision: D65574038

fbshipit-source-id: 491c63954b0ceb49eb65c4afb6f1a93706dfa525
  • Loading branch information
MatzeB authored and facebook-github-bot committed Nov 22, 2024
1 parent c952277 commit f110630
Showing 1 changed file with 125 additions and 27 deletions.
152 changes: 125 additions & 27 deletions src/EmbeddingSpMDMAutovec.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,19 @@
#define do_prefetch(...) __builtin_prefetch(__VA_ARGS__)
#endif

#ifdef __clang__
// https://github.com/llvm/llvm-project/issues/114891 / T206675074
// Work around the LLVM loop vectorizer not producing optimal code when
// `block_size` is not a multiple of the natural vector size.
//
// FBGEMM_VECTOR_WIDTH is selected from the widest x86 SIMD extension enabled
// at compile time; the values equal the number of 32-bit lanes in a
// 512-/256-/128-bit register. The macro is intentionally left undefined for
// non-clang compilers and for targets with none of these extensions, so every
// use must be guarded with `#ifdef FBGEMM_VECTOR_WIDTH`.
//
// `defined(...)` is used on every branch so that no undefined identifier is
// evaluated in an `#elif` (keeps the block clean under -Wundef).
#if defined(__AVX512F__)
#define FBGEMM_VECTOR_WIDTH 16
#elif defined(__AVX2__)
#define FBGEMM_VECTOR_WIDTH 8
#elif defined(__SSE__)
#define FBGEMM_VECTOR_WIDTH 4
#endif
#endif // #ifdef __clang__

namespace fbgemm {

static constexpr size_t LOCAL_STORAGE_SIZE = 512;
Expand Down Expand Up @@ -142,7 +155,14 @@ static bool ALWAYS_INLINE EmbeddingSpMDM8Bit_autovec(
}

const uint8_t* input_row = input_row_base + input_offset;
for (int64_t j = 0; j < block_size; ++j) {
int64_t j = 0;
#ifdef FBGEMM_VECTOR_WIDTH
for (; j < block_size - (block_size % FBGEMM_VECTOR_WIDTH); ++j) {
uint8_t value = input_row[j];
buf[j] = std::fma(scale, (float)value, buf[j] + bias);
}
#endif
for (; j < block_size; ++j) {
uint8_t value = input_row[j];
buf[j] = std::fma(scale, (float)value, buf[j] + bias);
}
Expand Down Expand Up @@ -210,16 +230,16 @@ static bool ALWAYS_INLINE EmbeddingSpMDM8Bit_autovec(
}

const uint8_t* input_row = input_row_base + input_offset;
if (block_size <= 64) {
for (int64_t j = 0; j < block_size; ++j) {
uint8_t value = input_row[j];
buf[j] = std::fma(scale, (float)value, buf[j] + bias);
}
} else {
for (int64_t j = 0; j < block_size; ++j) {
uint8_t value = input_row[j];
buf[j] = std::fma(scale, (float)value, buf[j] + bias);
}
int64_t j = 0;
#ifdef FBGEMM_VECTOR_WIDTH
for (; j < block_size - (block_size % FBGEMM_VECTOR_WIDTH); ++j) {
uint8_t value = input_row[j];
buf[j] = std::fma(scale, (float)value, buf[j] + bias);
}
#endif
for (; j < block_size; ++j) {
uint8_t value = input_row[j];
buf[j] = std::fma(scale, (float)value, buf[j] + bias);
}
}
if (normalize_by_lengths && len) {
Expand Down Expand Up @@ -358,16 +378,42 @@ static bool ALWAYS_INLINE EmbeddingSpMDMNBit_autovec(
}

if (input_bit_rate == 4) {
for (int64_t j = 0, k = 0; j < block_size; j += 2) {
uint8_t tmp = input_row[k++];
int64_t j = 0;
#ifdef FBGEMM_VECTOR_WIDTH
for (; j < block_size - (block_size % (FBGEMM_VECTOR_WIDTH * 2));
j += 2) {
uint8_t tmp = *input_row++;
float quantized1 = float(tmp & 0xf);
float quantized2 = float(tmp >> 4);
buf[j] = std::fma(scale, quantized1, buf[j] + bias);
buf[j + 1] = std::fma(scale, quantized2, buf[j + 1] + bias);
}
#endif
for (; j < block_size; j += 2) {
uint8_t tmp = *input_row++;
float quantized1 = float(tmp & 0xf);
float quantized2 = float(tmp >> 4);
buf[j] = std::fma(scale, quantized1, buf[j] + bias);
buf[j + 1] = std::fma(scale, quantized2, buf[j + 1] + bias);
}
} else if (input_bit_rate == 2) {
for (int64_t j = 0, k = 0; j < block_size; j += 4) {
uint8_t tmp = input_row[k++];
int64_t j = 0;
#ifdef FBGEMM_VECTOR_WIDTH
for (; j < block_size - (block_size % (FBGEMM_VECTOR_WIDTH * 4));
j += 4) {
uint8_t tmp = *input_row++;
float quantized1 = float(tmp & 0x3);
float quantized2 = float((tmp & 0xC) >> 2);
float quantized3 = float((tmp & 0x30) >> 4);
float quantized4 = float(tmp >> 6);
buf[j] = std::fma(scale, quantized1, buf[j] + bias);
buf[j + 1] = std::fma(scale, quantized2, buf[j + 1] + bias);
buf[j + 2] = std::fma(scale, quantized3, buf[j + 2] + bias);
buf[j + 3] = std::fma(scale, quantized4, buf[j + 3] + bias);
}
#endif
for (; j < block_size; j += 4) {
uint8_t tmp = *input_row++;
float quantized1 = float(tmp & 0x3);
float quantized2 = float((tmp & 0xC) >> 2);
float quantized3 = float((tmp & 0x30) >> 4);
Expand Down Expand Up @@ -481,13 +527,28 @@ static bool ALWAYS_INLINE EmbeddingSpMDM_autovec(

if (weights != nullptr) {
float weight = weights[m];
for (int64_t j = 0; j < block_size; ++j) {
int64_t j = 0;
#ifdef FBGEMM_VECTOR_WIDTH
for (; j < block_size - (block_size % FBGEMM_VECTOR_WIDTH); ++j) {
const InType* inptr = input + input_stride * idx + j;
buf[j] = std::fma(
weight, convert_to_float_ref(*inptr, is_bf16_in), buf[j]);
}
#endif
for (; j < block_size; ++j) {
const InType* inptr = input + input_stride * idx + j;
buf[j] = std::fma(
weight, convert_to_float_ref(*inptr, is_bf16_in), buf[j]);
}
} else {
for (int64_t j = 0; j < block_size; ++j) {
int64_t j = 0;
#ifdef FBGEMM_VECTOR_WIDTH
for (; j < block_size - (block_size % FBGEMM_VECTOR_WIDTH); ++j) {
const InType* inptr = input + input_stride * idx + j;
buf[j] += convert_to_float_ref(*inptr, is_bf16_in);
}
#endif
for (; j < block_size; ++j) {
const InType* inptr = input + input_stride * idx + j;
buf[j] += convert_to_float_ref(*inptr, is_bf16_in);
}
Expand Down Expand Up @@ -559,9 +620,17 @@ static bool ALWAYS_INLINE EmbeddingSpMDM_autovec(
w = weights[is_weight_positional ? i : current];
}

for (int64_t j = 0; j < block_size; ++j) {
const InType* inptr = input + input_stride * idx + j;
buf[j] = std::fma(w, convert_to_float_ref(*inptr, is_bf16_in), buf[j]);
const InType* input_row = input + input_stride * idx;
int64_t j = 0;
#ifdef FBGEMM_VECTOR_WIDTH
for (; j < block_size - (block_size % FBGEMM_VECTOR_WIDTH); ++j) {
InType value = *input_row++;
buf[j] = std::fma(w, convert_to_float_ref(value, is_bf16_in), buf[j]);
}
#endif
for (; j < block_size; ++j) {
InType value = *input_row++;
buf[j] = std::fma(w, convert_to_float_ref(value, is_bf16_in), buf[j]);
}

++current;
Expand Down Expand Up @@ -642,9 +711,17 @@ static bool ALWAYS_INLINE EmbeddingSpMDMRowWiseSparse_autovec(
bias *= weight;
}

for (int j = 0; j < block_size; ++j) {
out[j] =
std::fma(scale, input[fused_block_size * idx + j], out[j] + bias);
const InType* input_row = input + fused_block_size * idx;
int64_t j = 0;
#ifdef FBGEMM_VECTOR_WIDTH
for (; j < block_size - (block_size % FBGEMM_VECTOR_WIDTH); ++j) {
InType value = *input_row++;
out[j] = std::fma(scale, value, out[j] + bias);
}
#endif
for (; j < block_size; ++j) {
InType value = *input_row++;
out[j] = std::fma(scale, value, out[j] + bias);
}
}
if (normalize_by_lengths && len) {
Expand Down Expand Up @@ -688,8 +765,20 @@ static bool ALWAYS_INLINE EmbeddingSpMDMRowWiseSparse_autovec(
weight = *weights_addr++;
}

for (int j = 0; j < block_size; ++j) {
const InType* inptr = input + block_size * idx + j;
const InType* input_row = input + block_size * idx;
int64_t j = 0;
#ifdef FBGEMM_VECTOR_WIDTH
for (; j < block_size - (block_size % FBGEMM_VECTOR_WIDTH); ++j) {
const InType* inptr = input_row++;
out[j] = std::fma(
weight,
std::is_same<InType, float16>::value ? cpu_half2float(*inptr)
: *inptr,
out[j]);
}
#endif
for (; j < block_size; ++j) {
const InType* inptr = input_row++;
out[j] = std::fma(
weight,
std::is_same<InType, float16>::value ? cpu_half2float(*inptr)
Expand Down Expand Up @@ -884,8 +973,17 @@ static bool ALWAYS_INLINE EmbeddingSpMDMFP8_autovec(
exponent_bias);

// Now accumulate the results using vectorized operations if possible
for (int j = 0; j < block_size; ++j) {
buf[j] = std::fma(w, converted_inputs[j], buf[j]);
const float* input_row = converted_inputs.get();
int64_t j = 0;
#ifdef FBGEMM_VECTOR_WIDTH
for (; j < block_size - (block_size % FBGEMM_VECTOR_WIDTH); ++j) {
float value = *input_row++;
buf[j] = std::fma(w, value, buf[j]);
}
#endif
for (; j < block_size; ++j) {
float value = *input_row++;
buf[j] = std::fma(w, value, buf[j]);
}
}
if (normalize_by_lengths && len) {
Expand Down

0 comments on commit f110630

Please sign in to comment.