
Adds better Affine support for GPUs when using CUDA 11. Introduces a new bias addition kernel for CUDA < 11 #778

Merged: 23 commits into marian-nmt:master on Apr 9, 2021

Conversation

@rhenry-nv (Contributor) commented Dec 15, 2020

Description

This is related to issue #680. This adds better affine support for GPU.

List of changes:

  • Adds a fused cublasLt operator when using CUDA 11. Prior CUDA versions do not support fusing the bias addition into the GEMM for non-int8 types (see the sketch below).
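
For reference, a minimal sketch of how cublasLt exposes this fusion, assuming CUDA 11's cublasLt API. This is illustrative only; the function name, FP32 types, and column-major layout choices below are mine, not the PR's code.

```cpp
#include <cublasLt.h>
#include <cuda_runtime.h>

// Computes D = A * B + bias (column-major, FP32) with the bias add fused
// into the GEMM as a cublasLt epilogue. Error checking omitted for brevity.
void affineWithBiasEpilogue(cublasLtHandle_t ltHandle,
                            const float* A, const float* B, const float* bias,
                            float* D, int m, int n, int k,
                            cudaStream_t stream) {
  cublasLtMatmulDesc_t opDesc;
  cublasLtMatmulDescCreate(&opDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F);

  // The CUDA 11 feature this PR relies on: the bias add (optionally with
  // relu) runs as a GEMM epilogue instead of a separate kernel.
  cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_BIAS;  // or CUBLASLT_EPILOGUE_RELU_BIAS
  cublasLtMatmulDescSetAttribute(opDesc, CUBLASLT_MATMUL_DESC_EPILOGUE,
                                 &epilogue, sizeof(epilogue));
  cublasLtMatmulDescSetAttribute(opDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER,
                                 &bias, sizeof(bias));

  // Column-major layouts: A is m x k, B is k x n, D is m x n.
  cublasLtMatrixLayout_t aDesc, bDesc, dDesc;
  cublasLtMatrixLayoutCreate(&aDesc, CUDA_R_32F, m, k, m);
  cublasLtMatrixLayoutCreate(&bDesc, CUDA_R_32F, k, n, k);
  cublasLtMatrixLayoutCreate(&dDesc, CUDA_R_32F, m, n, m);

  const float alpha = 1.0f, beta = 0.0f;
  // Passing a null algo lets cublasLt pick one heuristically.
  cublasLtMatmul(ltHandle, opDesc, &alpha, A, aDesc, B, bDesc, &beta,
                 D, dDesc, D, dDesc, /*algo=*/nullptr,
                 /*workspace=*/nullptr, /*workspaceSizeInBytes=*/0, stream);

  cublasLtMatrixLayoutDestroy(dDesc);
  cublasLtMatrixLayoutDestroy(bDesc);
  cublasLtMatrixLayoutDestroy(aDesc);
  cublasLtMatmulDescDestroy(opDesc);
}
```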

Times from a proxy model with and without FP16. This model has a factored vocabulary.

This PR includes features not present in #743. I expect this PR to be more performant than #743 only for FP16 with CUDA 11.

PS. I believe the speedup at batch 1 is lower with FP16 due to a performance issue with cuBLAS 11: it causes a high startup cost for stridedBatchedGemms, which impacts batch 1 more than larger batch sizes. A workaround is to benchmark and specify algos to cuBLAS for different architectures. I have done this for Turing (T4) and Volta (Titan V and V100) but did not include it in this PR to keep it smaller. I can add this change if requested.
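
For illustration only (not part of this PR), the workaround described above might look roughly like the following. The algo value is a placeholder, not a benchmarked result, and newer cuBLAS releases may ignore the hint on some architectures.

```cpp
#include <cublas_v2.h>

// Pick a benchmarked cublasGemmAlgo_t per compute-capability major version
// instead of always using the default heuristic. Values here are placeholders.
static cublasGemmAlgo_t pickStridedBatchedAlgo(int ccMajor) {
  switch (ccMajor) {
    case 7:  return CUBLAS_GEMM_ALGO4_TENSOR_OP;   // e.g. Volta (7.0) / Turing (7.5)
    default: return CUBLAS_GEMM_DEFAULT_TENSOR_OP; // fall back to the heuristic
  }
}
// The chosen value is then passed as the final argument of
// cublasGemmStridedBatchedEx(...).
```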

Times with one stream

| Batch | Time from PR #776 (s) | Current FP32 time (s) | Current FP16 time (s) | Speedup (#776 vs FP16) |
|---|---|---|---|---|
| 1 | 105.638 | 112.154 | 88.9725 | 1.187310686 |
| 2 | 73.2865 | 70.7748 | 56.0738 | 1.306965107 |
| 4 | 45.1652 | 43.1463 | 34.2097 | 1.320245427 |
| 8 | 27.1344 | 25.6202 | 20.777 | 1.305982577 |
| 16 | 16.1808 | 15.1876 | 12.5525 | 1.28904999 |
| 32 | 10.1867 | 9.37464 | 7.54261 | 1.350553721 |
| 64 | 6.45151 | 5.81335 | 4.62424 | 1.395150338 |
| 128 | 4.39468 | 3.94626 | 3.10605 | 1.414877417 |
| 256 | 3.2189 | 2.83046 | 2.24654 | 1.43282559 |

Times with two streams

| Batch | Time from PR #776 (s) | Current FP32 time (s) | Current FP16 time (s) | Speedup (#776 vs FP16) |
|---|---|---|---|---|
| 1 | 94.137 | 84.9079 | 70.8766 | 1.328181657 |
| 2 | 61.736 | 53.8002 | 44.717 | 1.38059351 |
| 4 | 37.2677 | 32.717 | 26.8053 | 1.390310871 |
| 8 | 21.81 | 19.245 | 16.3961 | 1.330194375 |
| 16 | 12.6072 | 11.2914 | 9.7066 | 1.298827602 |
| 32 | 7.70231 | 6.953 | 5.79172 | 1.329883005 |
| 64 | 4.86047 | 4.36442 | 3.46255 | 1.403725578 |
| 128 | 3.3429 | 2.9891 | 2.33464 | 1.431869582 |
| 256 | 2.52556 | 2.21938 | 1.68216 | 1.501379179 |

Added dependencies: none

How to test

I ran the regression tests and most passed with CUDA 11. With CUDA 10, the regression tests pass as expected. I also manually tested on a proxy model with fp16 and the results seem sensible.

CMake command: cmake .. -DCOMPILE_CPU=on -DCOMPILE_CUDA=on -DUSE_SENTENCEPIECE=on -DUSE_STATIC_LIBS=off -DCOMPILE_SERVER=off -DUSE_FBGEMM=on -DCOMPILE_CUDA_SM35=off -DCOMPILE_CUDA_SM50=off -DCOMPILE_CUDA_SM60=off -DCOMPILE_CUDA_SM70=on -DCOMPILE_CUDA_SM75=off -DCOMPILE_TESTS=on

Ubuntu - 18.04.3 LTS
nvcc - 10.1.243
gcc - 7.5.0

Checklist

  • I have tested the code manually
  • I have run regression tests
  • I have read and followed CONTRIBUTING.md
  • I have updated CHANGELOG.md

@rhenry-nv (Contributor, Author):

I will stop submitting PRs for now and focus on addressing your feedback on the individual changes as it comes. The PRs submitted have the majority of the perf benefits of #743 for FP32.

@kpu (Member) commented Dec 15, 2020

Looking forward to this one!

@rhenry-nv changed the title "Adds Affine support for GPUs when using CUDA 11. Introduces a new bias addition kernel." to "Adds better Affine support for GPUs when using CUDA 11. Introduces a new bias addition kernel for CUDA < 11" on Dec 15, 2020
@emjotde (Member) commented Mar 22, 2021

FYI: looking into this now.

@emjotde (Member) commented Mar 22, 2021

@rhenry-nv can you make me a contributor in your fork? I am looking at the merge-conflicts here and would like to change a few things on the way.

@emjotde (Member) commented Mar 22, 2021

From what I am seeing here, we could in theory create an affine node that can take any activation function and run a kernel after the multiply, is that correct? This is more general than relu?

@rhenry-nv (Contributor, Author) commented Mar 23, 2021

I resolved merge conflicts on each of these branches a week ago.

Yes, we could make an affine node that takes any activation. The reason it is relu-specific is that cublasLt only allows relu and bias addition to be fused into the GEMM kernel. For other activation functions, it may make sense to use CUTLASS to get the fusion for a generic affine node.

Without the fusion, I don't expect a performance benefit: with multiple kernels we still pay launch overheads plus a trip to DRAM to do a small amount of work.
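
For reference, these are the cublasLt epilogues available around CUDA 11.0, to the best of my knowledge (illustrative listing, not code from this PR); bias and relu are the only post-GEMM ops it can fuse here.

```cpp
#include <cublasLt.h>

cublasLtEpilogue_t epilogueChoices[] = {
    CUBLASLT_EPILOGUE_DEFAULT,    // plain GEMM, no epilogue
    CUBLASLT_EPILOGUE_RELU,       // relu(D)
    CUBLASLT_EPILOGUE_BIAS,       // D + bias
    CUBLASLT_EPILOGUE_RELU_BIAS,  // relu(D + bias)
};
```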

@rhenry-nv (Contributor, Author):

FYI - I have sent an invite. I also have the other branches locally merged with a more recent version of master, and I will push those updates. I will take care of any merge conflicts with the current master later this week.

@emjotde (Member) commented Mar 23, 2021

Alright, got the invite and accepted. So relu is in fact special in this case, and it makes sense to single it out. I will take a look into this. Might refactor a few things.

@rhenry-nv (Contributor, Author):

Yes, relu is special in this case. Something more general is definitely a separate PR.

@emjotde (Member) commented Mar 23, 2021

Yes, I am actually partial to doing the opposite. Since affine+relu is special it deserves to be handled more like its own thing rather than trying to cram everything into the original Affine node. Gimme some time to think about that.

@emjotde (Member) commented Mar 23, 2021

Do we know if other frameworks like Pytorch have that fusion? And how they are handling that?

@rhenry-nv (Contributor, Author) commented Mar 23, 2021

I am not sure about PyTorch's ability to fuse into GEMMs. From quick googling, it looks like it can JIT-fuse several pointwise ops. I'm not sure if TorchScript is advanced enough to fuse into GEMMs.

When I get some more time, I can run a simple PyTorch program with nvprof to see if there is any fusion happening. For TensorRT, we do these fusions manually.

AFAIK, the way one would handle fusions with GEMMs in general is by using CUTLASS.

@emjotde (Member) commented Mar 23, 2021

OK, don't worry about it too much. I was mostly curious if anyone else has explicitly created an affine+relu operator.

@emjotde (Member) commented Mar 24, 2021

That would also be justified because, if I am not wrong, one can do a relatively simple gradient computation for that fused operator thanks to the properties of relu specifically.
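
A short sketch of why the relu case is simple (my notation, not from the thread): the relu mask is recoverable from the output itself, so the fused forward op does not need to store the pre-activation.

```latex
% With Y = relu(X W + b) and upstream gradient dY:
\begin{aligned}
M  &= \mathbf{1}[Y > 0], \qquad \delta = dY \odot M,\\
dX &= \delta\, W^{\top}, \qquad dW = X^{\top} \delta, \qquad
db = \textstyle\sum_{\text{rows}} \delta.
\end{aligned}
```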

@emjotde (Member) commented Mar 25, 2021

Submitted a PR to your PR in your fork. Didn't touch the cuda code itself, just some reorganization of the API code.

@rhenry-nv (Contributor, Author):

Sounds good. I should have some time tomorrow to check performance.

@rhenry-nv (Contributor, Author) commented Mar 27, 2021

There seems to be a slowdown in the GPU backend compared to the first time I ran. I am not sure if it is coming from a change in master or this PR. I don't think it's this PR, since I checked perf some time ago. I will profile the most recent master branch next week.

@emjotde (Member) commented Mar 27, 2021

OK, thanks. There were a ton of changes, so it's possible.

@rhenry-nv (Contributor, Author):

Looks like the slowdown was introduced in the current master. At least for batch 64, on a proxy model, an older version of master (commit 467b15e) takes 11.44s to translate an entire dataset. The current commit takes 14.8s on the same dataset.

Trying to figure this out now.

@rhenry-nv (Contributor, Author):

@emjotde Your PR looks good compared to the current master, so I merged it into this PR to unblock it for now.

I profiled the GPU backend, and the time spent in forward looks similar (kernel time + CPU time to launch kernels, etc.). However, the time between calls to forward in beam search looks to have doubled. The beam search code has not changed much from the 'good' commit. Do you know if the way the graph is constructed changed much recently? This is my first guess, though unlikely.

@emjotde (Member) commented Apr 3, 2021 via email

@rhenry-nv (Contributor, Author) commented Apr 7, 2021

@emjotde Yep, it was TCMalloc. Master took 11.1s to decode the dataset with TCMalloc used. It was previously around 14.8s. Is there a reason you turned off TCMalloc?

EDIT: Ah, I see Commit 096c48e removed it. I think this demonstrates TCMalloc is still useful.

@emjotde (Member) commented Apr 8, 2021

Yes, reverted PR #840. Will take another look at that particular problem later this week.

@emjotde merged commit fddd0e0 into marian-nmt:master on Apr 9, 2021
@emjotde (Member) commented Apr 9, 2021

Merged. Thanks for your patience. Which one next?

@rhenry-nv (Contributor, Author):

Thanks! Let's do #768 then #770. I think these should be the least 'controversial'. Then #771. I expect #776 to require the most changes/redesign, but we get a lot of perf improvement from #776 for models with factored vocabs.

@emjotde (Member) commented Apr 9, 2021

Great, over to #768 then.

@emjotde (Member) commented May 11, 2021

Hi @rhenry-nv
We have a problem caused by this PR on A100s during training. Any idea what might be going on?
Reverting this PR makes it work again.

It complains about misaligned address as seen below:

[1,0]<stderr>:[2021-05-11 09:15:52] Error: CUDA error 716 'misaligned address' - /opt/marian-dev/src/tensors/gpu/algorithm.cu:15: cudaStreamSynchronize(0)
[1,0]<stderr>:[2021-05-11 09:15:52] Error: Aborted from void marian::gpu::copy(marian::Ptr<marian::Backend>, const T*, const T*, T*) [with T = unsigned int; marian::Ptr<marian::Backend> = std::shared_ptr<marian::Backend>] in /opt/marian-dev/src/tensors/gpu/algorithm.cu:15
[1,0]<stderr>:
[1,0]<stderr>:[CALL STACK]
[1,0]<stderr>:[0x55a5de0a6213]    void marian::gpu::  copy  <unsigned int>(std::shared_ptr<marian::Backend>,  unsigned int const*,  unsigned int const*,  unsigned int*) + 0x773
[1,0]<stderr>:[0x55a5dda7ad7c]    void marian::TensorBase::  set  <unsigned int>(unsigned int const*,  unsigned int const*) + 0x7bc
[1,0]<stderr>:[0x55a5dda7b074]    std::_Function_handler<void (IntrusivePtr<marian::TensorBase>),marian::inits::fromVector<unsigned int>(std::vector<unsigned int,std::allocator<unsigned int>> const&)::{lambda(IntrusivePtr<marian::TensorBase>)#1}>::  _M_invoke  (std::_Any_data const&,  IntrusivePtr<marian::TensorBase>&&) + 0x24
[1,0]<stderr>:[0x55a5dda73549]    marian::inits::LambdaInitConvert::  apply  (IntrusivePtr<marian::TensorBase>) + 0x79
[1,0]<stderr>:[0x55a5dda65131]    marian::ConstantNode::  init  ()                   + 0x51
[1,0]<stderr>:[0x55a5dda5543c]    marian::ExpressionGraph::  forward  (std::__cxx11::list<IntrusivePtr<marian::Chainable<IntrusivePtr<marian::TensorBase>>>,std::allocator<IntrusivePtr<marian::Chainable<IntrusivePtr<marian::TensorBase>>>>>&,  bool) + 0x7c
[1,0]<stderr>:[0x55a5dda56c5e]    marian::ExpressionGraph::  forwardNext  ()         + 0x2de
[1,0]<stderr>:[0x55a5ddc3b3f0]    marian::GraphGroup::  collectStats  (std::shared_ptr<marian::ExpressionGraph>,  std::shared_ptr<marian::models::ICriterionFunction>,  std::vector<std::shared_ptr<marian::Vocab>,std::allocator<std::shared_ptr<marian::Vocab>>> const&,  double) + 0xc20
[1,0]<stderr>:[0x55a5ddc25acc]    marian::SyncGraphGroup::  collectStats  (std::vector<std::shared_ptr<marian::Vocab>,std::allocator<std::shared_ptr<marian::Vocab>>> const&) + 0x11c
[1,0]<stderr>:[0x55a5dd88bbdd]    marian::Train<marian::SyncGraphGroup>::  run  ()   + 0x37d
[1,0]<stderr>:[0x55a5dd792532]    mainTrainer  (int,  char**)                        + 0xe2
[1,0]<stderr>:[0x55a5dd6ab42c]    main                                               + 0x3c
[1,0]<stderr>:[0x7f528a0a70b3]    __libc_start_main                                  + 0xf3
[1,0]<stderr>:[0x55a5dd79118e]    _start                                             + 0x2e

@rhenry-nv (Contributor, Author) commented May 12, 2021

Is there some sort of slicing going on in this model? Which CUDA version was this using? Is this half or float precision? What are the GEMM m/n/k dims, and is this doing both relu and bias addition?

Also, if it is possible, could you print the following?
C->data<float or half>() % 16
bias->data<float or half>() % 16

I came across a similar issue in 11.1 (I think), which is why the code path launches a separate kernel if the bias pointer is not a multiple of 8. However, I only checked this fix on Volta and Turing. A more recent version of cuBLAS also fixes the bug on Turing and Volta. I will check Ampere when I get some time.
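
Rough illustration of the guard described above (the helper name is made up, not the PR's actual code): the fused cublasLt bias epilogue is used only when the bias pointer meets the alignment requirement; otherwise we fall back to a plain GEMM followed by the separate bias-addition kernel.

```cpp
#include <cstdint>

static constexpr std::uintptr_t REQUIRED_BIAS_ALIGNMENT = 16;  // bytes

// Returns true when the bias pointer is sufficiently aligned for the fused path.
inline bool biasIsAligned(const void* biasPtr) {
  return reinterpret_cast<std::uintptr_t>(biasPtr) % REQUIRED_BIAS_ALIGNMENT == 0;
}

// if (biasIsAligned(bias->data<float>())) { /* fused GEMM + bias epilogue */ }
// else                                    { /* GEMM, then bias kernel     */ }
```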

@emjotde (Member) commented May 12, 2021

This will be CUDA 11.1 and fp32.

Currently that's all I can say until I manage to find out how to run a debug session on the new cluster. Will update here, hopefully soon.

Also, this is the training path, so it should not really run the combined matmul + relu (we only run that during inference). I would guess this is somehow caused directly by the matmul alone.

@rhenry-nv (Contributor, Author):

We use this matmul + bias fusion (no relu) during training (I think), which is where I initially saw the issue.

As a quick check, you could set REQUIRED_BIAS_ALIGNMENT to 16. If I am correct, it should run without failure.

I have a couple of guesses about what else could be wrong apart from this. I will take a look later this week to see if I can reproduce on an A100.

@emjotde (Member) commented May 12, 2021

So, currently your guess would be that the bias pointer might not be aligned to a multiple of 16 bytes? That might be the case for the output layer.

@rhenry-nv (Contributor, Author):

Yep, that is currently my best guess.

@emjotde (Member) commented May 17, 2021

static constexpr int REQUIRED_BIAS_ALIGNMENT = 16; 

Seems to do it. Are there any larger drawbacks?

@rhenry-nv (Contributor, Author):

The drawback is that we will only get fusion when the bias is 16-byte aligned. Otherwise, we fall back to the unfused code path, which is slightly slower.

If I remember right, the Marian allocator always gives 16-byte aligned pointers for the GPU backend (correct me if I'm wrong). If this is true, we will only take the slower path when we slice biases.

I will check a more recent CUDA version and file a bug internally if it still exists.

@emjotde (Member) commented May 17, 2021

256-byte aligned actually. Some time ago it turned out that matmul is a lot faster if that kind of alignment is enforced. Not sure if that is still the case.

OK, I will change things to 16 then. Thanks.
