Small optimizations #768

Open

wants to merge 14 commits into base: master

Changes from 8 commits
5 changes: 5 additions & 0 deletions CHANGELOG.md
@@ -9,6 +9,11 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
## [Unreleased]

### Added
- Uses the per-thread default stream for cuBLAS
- Uses the strided batched GEMM cuBLAS call for batched GEMM when possible
- In the general batched GEMM case, reduces the number of memcpy calls from 3 to 1
- Rounds the width of input batches up to a multiple of 8 when the GPU backend is used, to enable better use of tensor cores on Volta architectures and newer
- Adds NVIDIA notices to some files
- Add --train-embedder-rank for fine-tuning any encoder(-decoder) model for multi-lingual similarity via softmax-margin loss
- Add --logical-epoch that allows to redefine the displayed epoch counter as a multiple of n data epochs, updates or labels. Also allows to define width of fractional part with second argument.
- Add --metrics chrf for computing ChrF according to https://www.aclweb.org/anthology/W15-3049/ and SacreBLEU reference implementation
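(Background for the rounding entry above, as the corpus.cpp diff below implements it: cuBLAS only dispatches tensor-op kernels when the relevant GEMM dimensions are multiples of 8 in FP16, so padding batch widths up front makes tensor-core use on Volta and newer much more likely.)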
13 changes: 13 additions & 0 deletions src/data/corpus.cpp
@@ -1,3 +1,8 @@
/* Part of this file was contributed by NVIDIA under license:
* Copyright (C) 2020 NVIDIA Corporation
* SPDX-License-Identifier: MIT
Contributor:

What is this for? It seems the only change is the rounding of maxDims. What NVIDIA contribution was made here?

In general, I would be opposed to this style of license comment. Marian source code does not have license information in the source files directly, but rather in a separate license file. Please let's continue to follow that pattern.

Contributor Author:

Yes, that was the only contribution. However, I was told to include a notice even in files where I make one-line changes. This is just me following instructions.

Contributor Author:

Also, I just saw the second part of your comment. Is there a process for adding NVIDIA to the license file? I'm not sure what the solution is here.

Member:

The one that hasn't been updated since 2016 and doesn't even name Microsoft? I guess add it to your PR and shame Marcin.

Contributor Author:

@kpu I think this is the one Frank was referring to. I'll ask Marcin if this is ok. I will also need to check whether we are internally ok with removing the notices, assuming we are added to the license file.

@emjotde What do you suggest as the way forward?

Contributor Author (@rhenry-nv), Dec 15, 2020:
I checked internally and I can add NVIDIA to the main license file and remove the notices in all the other files!

I will take care of this in all the PRs I have submitted.

Edit: Let me know if the license change looks ok.

*/

#include "data/corpus.h"

#include <numeric>
@@ -266,6 +271,14 @@ CorpusBase::batch_ptr Corpus::toBatch(const std::vector<Sample>& batchVector) {
}
sentenceIds.push_back(ex.getId());
}

// When running on GPU, we want the batchWidth to be a multiple of 8 for better tensorcore usage
if(options_->get<int>("cpu-threads") == 0) {
constexpr int roundingFactor = 8;
for(size_t j = 0; j < maxDims.size(); ++j) {
maxDims[j] = roundingFactor * ((maxDims[j] + roundingFactor - 1) / roundingFactor);
}
}
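For illustration, a minimal standalone sketch (hypothetical helper, not part of the PR) of the rounding idiom above: the integer division (x + f - 1) / f rounds up, so multiplying back by f yields the smallest multiple of f that is greater than or equal to x.

#include <cassert>

static int roundUpToMultiple(int x, int factor) {
  return factor * ((x + factor - 1) / factor);
}

int main() {
  assert(roundUpToMultiple(1, 8) == 8);   // width 1 is padded to 8
  assert(roundUpToMultiple(8, 8) == 8);   // exact multiples are unchanged
  assert(roundUpToMultiple(17, 8) == 24); // width 17 is padded to 24
  return 0;
}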

std::vector<Ptr<SubBatch>> subBatches;
for(size_t j = 0; j < maxDims.size(); ++j) {
12 changes: 10 additions & 2 deletions src/layers/generic.cpp
@@ -1,3 +1,9 @@

/* Part of this file was contributed by NVIDIA under license:
* Copyright (C) 2020 NVIDIA Corporation
* SPDX-License-Identifier: MIT
*/

#include "marian.h"

#include "layers/generic.h"
@@ -88,8 +94,10 @@ namespace marian {
}

// if selIdx are given, then we must reshuffle accordingly
if (!hypIndices.empty()) // use the same function that shuffles decoder state
sel = rnn::State::select(sel, hypIndices, (int)beamSize, /*isBatchMajor=*/false);
if (!hypIndices.empty()) { // use the same function that shuffles decoder state
auto indices = graph()->indices(hypIndices);
sel = rnn::State::select(sel, indices, (int)beamSize, /*isBatchMajor=*/false);
}
return sel;
}
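(For context, my reading of the change rather than a claim from the PR: graph()->indices(hypIndices) materializes the host-side index vector as an index tensor on the graph's device, so rnn::State::select can take a ready device tensor instead of a raw std::vector. The same pattern is applied in src/models/states.h and src/rnn/types.h below.)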

8 changes: 7 additions & 1 deletion src/models/states.h
@@ -1,3 +1,8 @@
/* Part of this file was contributed by NVIDIA under license:
* Copyright (C) 2020 NVIDIA Corporation
* SPDX-License-Identifier: MIT
*/

#pragma once

#include "marian.h"
@@ -30,7 +35,8 @@ class EncoderState {
// Sub-select active batch entries from encoder context and context mask
Ptr<EncoderState> select(const std::vector<IndexType>& batchIndices) { // [batchIndex] indices of active batch entries
// Dimension -2 is OK for both, RNN and Transformer models as the encoder context in Transformer gets transposed to the same dimension layout
return New<EncoderState>(index_select(context_, -2, batchIndices), index_select(mask_, -2, batchIndices), batch_);
auto indices = context_->graph()->indices(batchIndices);
return New<EncoderState>(index_select(context_, -2, indices), index_select(mask_, -2, indices), batch_);
}
};

32 changes: 27 additions & 5 deletions src/rnn/types.h
@@ -1,5 +1,12 @@
/* Part of this file was contributed by NVIDIA under license:
* Copyright (C) 2020 NVIDIA Corporation
* SPDX-License-Identifier: MIT
*/

#pragma once

#include "common/definitions.h"
#include "common/shape.h"
#include "marian.h"

#include <iostream>
@@ -12,23 +19,22 @@ struct State {
Expr output;
Expr cell;

State select(const std::vector<IndexType>& selIdx, // [beamIndex * activeBatchSize + batchIndex]
State select(Expr selIdx, // [beamIndex * activeBatchSize + batchIndex]
int beamSize, bool isBatchMajor) const {
return{ select(output, selIdx, beamSize, isBatchMajor),
select(cell, selIdx, beamSize, isBatchMajor) };
}

// this function is also called by Logits
static Expr select(Expr sel, // [beamSize, dimTime, dimBatch, dimDepth] or [beamSize, dimBatch, dimTime, dimDepth] (dimTime = 1 for RNN)
const std::vector<IndexType>& selIdx, // [beamIndex * activeBatchSize + batchIndex]
Expr selIdx, // [beamIndex * activeBatchSize + batchIndex]
int beamSize, bool isBatchMajor)
{
if (!sel)
return sel; // keep nullptr untouched

sel = atleast_4d(sel);

int dimBatch = (int)selIdx.size() / beamSize;
int dimBatch = (int)selIdx->shape().elements() / beamSize;
int dimDepth = sel->shape()[-1];
int dimTime = isBatchMajor ? sel->shape()[-2] : sel->shape()[-3];

@@ -83,8 +89,24 @@ class States {
States select(const std::vector<IndexType>& selIdx, // [beamIndex * activeBatchSize + batchIndex]
int beamSize, bool isBatchMajor) const {
States selected;
Expr indices;
// To ship selIdx to the device we need a graph; borrow it from the first
// state that has a cell or output expression. (I think this doesn't work
// if the model is split among GPUs, but not sure if it matters.)

for (auto& state : states_) {
Contributor:

Please add a comment explaining what this logic does, as it is not obvious why the old code was not working.

Contributor Author (@rhenry-nv), Dec 4, 2020:

There is nothing wrong with the old code. This just needs to check whether we need to ship indices to the GPU. I will add a comment explaining what this does.

Contributor Author:

I added a comment in a recent commit, but I'm not sure if it's clear enough. If it is, feel free to resolve this.

if (state.cell) {
indices = state.cell->graph()->indices(selIdx);
break;
}

if (state.output) {
indices = state.output->graph()->indices(selIdx);
break;
}
Contributor:

What happens if neither? Is that a valid condition? If not, let's change this to an else.

Contributor Author:

I think neither is a valid condition.

Contributor:

Ah, because it's a loop, sorry. But what happens if it never matches any of the conditions? Then indices ends up being NULL. Is it worth adding an ABORT_IF for that?

Contributor Author:

Good point. I think indices can only end up being NULL if all of the states' output and cell fields are null. In that case, we will return a vector of nulls, which was the behavior of the original code. (Also, I think these values are NULL on the first run of a network, so we want this behavior.)

}

// GPU OPT: Implement kernel to batch these on GPU
for(auto& state : states_)
selected.push_back(state.select(selIdx, beamSize, isBatchMajor));
selected.push_back(state.select(indices, beamSize, isBatchMajor));
return selected;
}
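(Worth noting about the hoisting above: all states in the vector share the same beam layout, so the indices tensor is created once, on the graph of the first state with a non-null cell or output, and reused by every state.select call, instead of each call uploading its own copy of selIdx.)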

7 changes: 7 additions & 0 deletions src/tensors/gpu/backend.h
@@ -1,3 +1,8 @@
/* Part of this file was contributed by NVIDIA under license:
* Copyright (C) 2020 NVIDIA Corporation
* SPDX-License-Identifier: MIT
*/

#pragma once

#include "common/config.h"
@@ -52,6 +57,7 @@ class Backend : public marian::Backend {
if(!cublasHandle_) { // lazy initialization here to avoid memory usage when unused
setDevice();
cublasCreate(&cublasHandle_);
cublasSetStream(cublasHandle_, cudaStreamPerThread);
}
return cublasHandle_;
}
@@ -60,6 +66,7 @@
if(!cusparseHandle_) { // lazy initialization here to avoid memory usage when unused
setDevice();
cusparseCreate(&cusparseHandle_);
cusparseSetStream(cusparseHandle_, cudaStreamPerThread);
}
return cusparseHandle_;
}
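(Context, my reading rather than a claim from the PR: cudaStreamPerThread is a predefined CUDA stream handle that resolves to a separate implicit stream for each host thread and does not synchronize with the legacy default stream 0. Binding it to the lazily created cuBLAS and cuSPARSE handles lets multiple host threads issue work through the same backend without serializing on the global default stream.)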
163 changes: 121 additions & 42 deletions src/tensors/gpu/prod.cpp
@@ -1,4 +1,8 @@

/* Part of this file was contributed by NVIDIA under license:
* Copyright (C) 2020 NVIDIA Corporation
* SPDX-License-Identifier: MIT
*/

#ifdef _MSC_VER
#pragma warning(disable: 4505) // warning C4505: '__float2half_rz': unreferenced local function has been removed (missing 'static inline')
#endif
@@ -247,6 +251,63 @@ cublasStatus_t cublasGemmBatchedTyped(cublasHandle_t handle,
}
#endif

cublasStatus_t cublasGemmBatchedStridedTyped(cublasHandle_t handle,
CudaCompute computeCapability,
cublasOperation_t transa,
cublasOperation_t transb,
int m, int n, int k,
const float *alpha,
const float *A, int lda, int strideA,
const float *B, int ldb, int strideB,
const float *beta,
float *C, int ldc, int strideC,
int batchCount) {
// double #if and if unfortunately required to safeguard against compilation error
// with CUDA 8.0 and runtime error with CUDA >9.0 on GPUs with compute capability under 5
#if CUDA_VERSION > 9000
// query math mode and set algorithm accordingly
auto algorithm = tensorOpsEnabled(handle) ? CUBLAS_GEMM_DEFAULT_TENSOR_OP : CUBLAS_GEMM_DEFAULT;
if(computeCapability.major >= 5)
return cublasGemmStridedBatchedEx(handle, transa, transb,
m, n, k, alpha,
(void const*)A, CUDA_R_32F, lda, strideA,
(void const*)B, CUDA_R_32F, ldb, strideB, beta,
(void*)C, CUDA_R_32F, ldc, strideC, batchCount,
CUDA_R_32F, algorithm);
#endif
return cublasSgemmStridedBatched(handle, transa, transb,
m, n, k, alpha,
A, lda, strideA,
B, ldb, strideB,
beta,
C, ldc, strideC,
batchCount);
}
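For reference, a minimal CPU sketch of the semantics of the strided batched call (my own illustration, assuming no transposes, column-major layout, and float; not part of the PR or of cuBLAS): matrix b of each operand lives at its base pointer plus b times the stride, and a stride of 0 reuses one matrix for every batch entry.

#include <cstddef>

// C[b] = alpha * A[b] * B[b] + beta * C[b] for b in [0, batchCount)
void gemmStridedBatchedRef(int m, int n, int k, float alpha,
                           const float* A, int lda, std::ptrdiff_t strideA,
                           const float* B, int ldb, std::ptrdiff_t strideB,
                           float beta,
                           float* C, int ldc, std::ptrdiff_t strideC,
                           int batchCount) {
  for(int b = 0; b < batchCount; ++b) {
    const float* Ab = A + b * strideA; // m x k, column-major
    const float* Bb = B + b * strideB; // k x n, column-major
    float* Cb = C + b * strideC;       // m x n, column-major
    for(int j = 0; j < n; ++j)
      for(int i = 0; i < m; ++i) {
        float acc = 0.0f;
        for(int p = 0; p < k; ++p)
          acc += Ab[i + p * lda] * Bb[p + j * ldb];
        Cb[i + j * ldc] = alpha * acc + beta * Cb[i + j * ldc];
      }
  }
}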

#if COMPILE_FP16 // should not be visible for CUDA 9.0 and below
cublasStatus_t cublasGemmBatchedStridedTyped(cublasHandle_t handle,
CudaCompute computeCapability,
cublasOperation_t transa,
cublasOperation_t transb,
int m, int n, int k,
const half *alpha,
const half *A, int lda, int strideA,
const half *B, int ldb, int strideB,
const half *beta,
half *C, int ldc, int strideC,
int batchCount) {
ABORT_IF(computeCapability.major < 6, "Compute capability {} below 6 should not happen for FP16", computeCapability.major);
// query math mode and set algorithm accordingly
auto algorithm = tensorOpsEnabled(handle) ? CUBLAS_GEMM_DEFAULT_TENSOR_OP : CUBLAS_GEMM_DEFAULT;
return cublasGemmStridedBatchedEx(handle, transa, transb,
m, n, k, alpha,
(void const*)A, CUDA_R_16F, lda, strideA,
(void const*)B, CUDA_R_16F, ldb, strideB, beta,
(void*)C, CUDA_R_16F, ldc, strideC, batchCount,
CUDA_R_16F, algorithm);
}
#endif

template <typename T>
void ProdBatchedTyped(marian::Tensor C,
Ptr<Allocator> allocator,
@@ -289,50 +350,68 @@ void ProdBatchedTyped(marian::Tensor C,
auto strideA = batchA == 1 ? 0 : m * k;
auto strideB = batchB == 1 ? 0 : n * k;
auto strideC = n * m;
auto batchC = std::max(batchA, batchB);

std::vector<const T*> aptr;
std::vector<const T*> bptr;
std::vector<T*> cptr;
if(batchA == batchB) {
setTensorMode(cublasHandle);
CUBLAS_CHECK(cublasGemmBatchedStridedTyped(cublasHandle,
compute,
opB,
opA,
n,
m,
k,
&alpha,
B->data<const T>(),
ldb, strideB,
A->data<const T>(),
lda, strideA,
&beta,
C->data<T>(),
ldc, strideC,
batchA));
unsetTensorMode(cublasHandle);
} else {
auto batchC = std::max(batchA, batchB);
size_t size = 3*batchC;
std::vector<T*> ptrs(size);
auto aStart = 0;
auto bStart = batchC;
auto cStart = bStart + batchC;

for(int i = 0; i < batchC; i++) {
ptrs[aStart + i] = A->data<T>() + (i % batchA) * strideA;
ptrs[bStart + i] = B->data<T>() + (i % batchB) * strideB;
ptrs[cStart + i] = C->data<T>() + i * strideC;
}

for(int i = 0; i < batchC; i++) {
aptr.push_back(A->data<T>() + (i % batchA) * strideA);
bptr.push_back(B->data<T>() + (i % batchB) * strideB);
cptr.push_back(C->data<T>() + i * strideC);
// auto fails here for some weird reason
IPtr<MemoryPiece> mp_ptrs = allocator->alloc<T*>(size);
T** dest = mp_ptrs->data<T*>();
cudaStream_t cublasStream = 0;
CUBLAS_CHECK(cublasGetStream(cublasHandle, &cublasStream));
CUDA_CHECK(cudaMemcpyAsync(dest, ptrs.data(), size * sizeof(T*), cudaMemcpyHostToDevice, cublasStream));

setTensorMode(cublasHandle);
CUBLAS_CHECK(cublasGemmBatchedTyped(cublasHandle,
compute,
opB,
opA,
n,
m,
k,
&alpha,
mp_ptrs->data<const T*>() + bStart,
ldb,
mp_ptrs->data<const T*>() + aStart,
lda,
&beta,
mp_ptrs->data<T*>() + cStart,
ldc,
batchC));
unsetTensorMode(cublasHandle);

allocator->free(mp_ptrs);
}
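(This else branch is what the changelog entry about memcpy refers to: the A, B, and C pointer arrays are packed into one host vector and one device allocation, so a single cudaMemcpyAsync on the cuBLAS stream replaces the three separate CudaCopy calls of the old code shown below.)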

// auto fails here from weird reason
IPtr<MemoryPiece> mp_aptr = allocator->alloc<const T*>(aptr.size());
CudaCopy(aptr.data(), aptr.data() + aptr.size(), mp_aptr->data<const T*>());

IPtr<MemoryPiece> mp_bptr = allocator->alloc<const T*>(bptr.size());
CudaCopy(bptr.data(), bptr.data() + bptr.size(), mp_bptr->data<const T*>());

IPtr<MemoryPiece> mp_cptr = allocator->alloc<T*>(cptr.size());
CudaCopy(cptr.data(), cptr.data() + cptr.size(), mp_cptr->data<T*>());

setTensorMode(cublasHandle);
CUBLAS_CHECK(cublasGemmBatchedTyped(cublasHandle,
compute,
opB,
opA,
n,
m,
k,
&alpha,
mp_bptr->data<const T*>(),
ldb,
mp_aptr->data<const T*>(),
lda,
&beta,
mp_cptr->data<T*>(),
ldc,
batchC));
unsetTensorMode(cublasHandle);

allocator->free(mp_aptr);
allocator->free(mp_bptr);
allocator->free(mp_cptr);
}

void ProdBatched(marian::Tensor C,