forked from pytorch/FBGEMM
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Port oss f16_fast_gemv into fbcode (pytorch#3610)
Summary: X-link: facebookresearch/FBGEMM#688 This diff content includes: 1. Port OSS FastGEMV `fp16` kernel into fbcode and expose to python as a step 1 - `torch.ops.fbgemm.f16_fast_gemv` https://github.com/wangsiping97/FastGEMV/blob/1fdff6f74aade033c02727a419afd6a4b4bfbc3f/fast_gemv.cu#L14 2. Add `fp16_oss_fast_gemv` to quantize ops benchmark script 3. Add two simple tests for custom op`torch.ops.fbgemm.f16_fast_gemv` to test - `torch.compile()` able - correctness Perf numbers: P1720649201 compared with `f16_baseline,fp16_oss_fast_gemv,cuda_lite,marlin_bf16i4,machete_bf16i4` **Next step:** Need fp8 mixed precision support for fast gemv kernel which is what we want Differential Revision: D68470488
- Loading branch information
1 parent
73f6cee
commit 9c2cd76
Showing
9 changed files
with
604 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
MIT License | ||
|
||
Copyright (c) 2023 Siping Wang | ||
|
||
Permission is hereby granted, free of charge, to any person obtaining a copy | ||
of this software and associated documentation files (the "Software"), to deal | ||
in the Software without restriction, including without limitation the rights | ||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
copies of the Software, and to permit persons to whom the Software is | ||
furnished to do so, subject to the following conditions: | ||
|
||
The above copyright notice and this permission notice shall be included in all | ||
copies or substantial portions of the Software. | ||
|
||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||
SOFTWARE. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
65 changes: 65 additions & 0 deletions
65
fbgemm_gpu/experimental/gen_ai/src/quantize/fast_gemv/fp16_fast_gemv.cu
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
/* | ||
* Copyright (c) Meta Platforms, Inc. and affiliates. | ||
* All rights reserved. | ||
* | ||
* This source code is licensed under the BSD-style license found in the | ||
* LICENSE file in the root directory of this source tree. | ||
*/ | ||
|
||
#include <ATen/ATen.h> | ||
#include <ATen/cuda/CUDAContext.h> | ||
#include <c10/core/ScalarType.h> | ||
#include <c10/cuda/CUDAGuard.h> | ||
|
||
#include "include/fast_gemv.cuh" | ||
|
||
namespace fbgemm_gpu { | ||
|
||
#if CUDART_VERSION >= 12000 | ||
|
||
at::Tensor fp16_fast_gemv(at::Tensor X, at::Tensor W) { | ||
// note: oss fast gemv implementation accepts vector shape as (size, 1) i.e. | ||
// (K, M) | ||
// X: K x M | ||
// W: N x K | ||
auto m = X.size(1); | ||
auto n = W.size(0); | ||
auto k = W.size(1); | ||
|
||
TORCH_CHECK(X.is_cuda() && X.is_contiguous()); | ||
TORCH_CHECK(W.is_cuda() && W.is_contiguous()); | ||
|
||
// the block_dim values are sweeped results from different problem sizes. | ||
// see tuning scripts: | ||
// src/quantize/fast_gemv/sweep_utils.py | ||
dim3 block_dim(32, 32); | ||
dim3 grid_dim(1, n / block_dim.y); | ||
unsigned int num_per_thread = k / block_dim.x; | ||
|
||
auto stream = at::cuda::getCurrentCUDAStream(); | ||
|
||
auto out_sizes = X.sizes().vec(); | ||
out_sizes.front() = n; | ||
auto Y = at::empty(out_sizes, X.options().dtype(at::kHalf)); | ||
|
||
gemv_fp16<<<grid_dim, block_dim, 0, stream>>>( | ||
(half*)W.data_ptr(), // mat | ||
(half*)X.data_ptr(), // vec | ||
(half*)Y.data_ptr(), // res | ||
k, | ||
num_per_thread); | ||
|
||
C10_CUDA_KERNEL_LAUNCH_CHECK(); | ||
|
||
return Y; | ||
} | ||
|
||
#else | ||
|
||
at::Tensor fast_gemv(at::Tensor X, at::Tensor W) { | ||
throw std::runtime_error( | ||
"CUDA version is older than 12.0"); // requires CUDA>=12 | ||
} | ||
#endif | ||
|
||
} // namespace fbgemm_gpu |
Oops, something went wrong.