Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add optimised max_element implementation as a specific case of nth_element with beam = 1 #981

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.

## [Unreleased]

### Added
- Optimised special case for nth_element when decoding with beam size of 1.

## [1.12.0] - 2023-02-20

### Added
Expand Down
138 changes: 124 additions & 14 deletions src/translator/nth_element.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,109 @@ class NthElementCPU {
NthElementCPU() {}
NthElementCPU(const NthElementCPU& copy) = delete;

// Efficient max_element implementations following https://github.com/XapaJIaMnu/maxelem_test
private:
#ifdef _MSC_VER
#pragma warning(push)
#pragma warning(disable : 4267)
#endif
#if defined(__AVX512F__)
int max_elem(const float * vec, size_t size) {
float maxVal = vec[0];
int max_idx = 0;
div_t setup = div(size, 16);
int overhang = setup.rem;
int seq = setup.quot;
__m512 maxvalVec = _mm512_set1_ps(maxVal);
for (int i = 0; i < seq*16; i+=16) {
auto res = _mm512_cmp_ps_mask(maxvalVec, _mm512_load_ps(&vec[i]), _CMP_LT_OS);
if (res != 0) {
for (int j = 0; j<16; j++) {
if (vec[i+j] > maxVal) {
maxVal = vec[i+j];
max_idx = i+j;
}
}
maxvalVec = _mm512_set1_ps(maxVal);
}
}
// Take care of the overhang
for (int i = seq*16; i < seq*16 + overhang; i++) {
if (maxVal < vec[i]) {
max_idx = i;
maxVal = vec[i];
}
}
return max_idx;
}
#elif defined(__AVX__)
int max_elem(const float * vec, size_t size) {
float maxVal = vec[0];
int max_idx = 0;
div_t setup = div(size, 8);
int overhang = setup.rem;
int seq = setup.quot;
__m256 maxvalVec = _mm256_set1_ps(maxVal);
for (int i = 0; i < seq*8; i+=8) {
__m256 res = _mm256_cmp_ps(maxvalVec, _mm256_load_ps(&vec[i]), _CMP_LT_OS);
if (_mm256_movemask_ps(res) != 0) {
for (int j = 0; j<8; j++) {
if (vec[i+j] > maxVal) {
maxVal = vec[i+j];
max_idx = i+j;
}
}
maxvalVec = _mm256_set1_ps(maxVal);
}
}
// Take care of the overhang
for (int i = seq*8; i < seq*8 + overhang; i++) {
if (maxVal < vec[i]) {
max_idx = i;
maxVal = vec[i];
}
}
return max_idx;
}
#elif defined(__SSE__)
int max_elem(const float * vec, size_t size) {
float maxVal = vec[0];
int max_idx = 0;
div_t setup = div(size, 4);
int overhang = setup.rem;
int seq = setup.quot;
__m128 maxvalVec = _mm_set1_ps(maxVal);
for (int i = 0; i < seq*4; i+=4) {
__m128 res = _mm_cmplt_ps(maxvalVec, _mm_load_ps(&vec[i]));
// We might have more than one increased matches, so not sure if this can be further optimised
if (_mm_movemask_ps(res) != 0) {
for (int j = 0; j<4; j++) {
if (vec[i+j] > maxVal) {
maxVal = vec[i+j];
max_idx = i+j;
}
}
maxvalVec = _mm_set1_ps(maxVal);
}
}
// Take care of the overhang
for (int i = seq*4; i < seq*4 + overhang; i++) {
if (maxVal < vec[i]) {
max_idx = i;
maxVal = vec[i];
}
}
return max_idx;
}
#else
int max_elem(const float * vec, size_t size) {
auto elem = std::max_element(vec, vec + size);
return std::distance(vec, elem);
}
#endif
#ifdef _MSC_VER
#pragma warning( pop )
#endif

public:
void getNBestList(Tensor scores, // [dimBatch, 1, beamSize, dimVocab or dimShortlist]
Expand All @@ -45,22 +148,29 @@ class NthElementCPU {
std::iota(idxs.begin(), idxs.end(), 0);

for(size_t batchIdx = 0; batchIdx < dimBatch; ++batchIdx) {

std::partial_sort(
// sorts the top N (beam size) idxs by score to the front
idxs.begin(),
idxs.begin() + N,
idxs.end(),
[&](int a, int b) { return scoresData[a] > scoresData[b]; }
);

// copy top N idxs and scores to return vectors
for(size_t i = 0; i < N; ++i) {
int idx = idxs[i];
if (N == 1) { // Special case for beam size, which is much faster
int idx = max_elem(scoresData, batchOffset); // n-th vocabulary item is the vocabulary id
// since idxs is re-used for each batch, add batch offset to each idx to get absolute position
h_res_idx[pos] = (int) (idx + batchIdx * batchOffset);
h_res[pos] = scoresData[idx];
h_res_idx[pos] = (int)(idx + batchIdx * batchOffset);
h_res[pos] = scoresData[idx];
++pos;
} else {
std::partial_sort(
// sorts the top N (beam size) idxs by score to the front
idxs.begin(),
idxs.begin() + N,
idxs.end(),
[&](int a, int b) { return scoresData[a] > scoresData[b]; }
);

// copy top N idxs and scores to return vectors
for(size_t i = 0; i < N; ++i) {
int idx = idxs[i];
// since idxs is re-used for each batch, add batch offset to each idx to get absolute position
h_res_idx[pos] = (int) (idx + batchIdx * batchOffset);
h_res[pos] = scoresData[idx];
++pos;
}
}

// advance pointer to next batch's beginning
Expand Down