Commit 9a5035a
fix compilation
HandH1998 committed Jan 21, 2025
1 parent 1af98dc commit 9a5035a
Showing 2 changed files with 19 additions and 5 deletions.
1 change: 1 addition & 0 deletions sgl-kernel/setup.py
@@ -42,6 +42,7 @@ def update_wheel_platform_tag():
"-gencode=arch=compute_90,code=sm_90",
"-U__CUDA_NO_HALF_OPERATORS__",
"-U__CUDA_NO_HALF2_OPERATORS__",
"-DNDEBUG",
]


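The new "-DNDEBUG" nvcc flag defines the standard NDEBUG macro for the kernel build, so assert() calls in host and device code (including CUTLASS debug-only assertions) compile to no-ops. Below is a minimal standalone sketch of that effect; the toy kernel check_kernel and the demo main() are not part of this commit.

#include <cassert>
#include <cstdio>
#include <cuda_runtime.h>

// Device-side assert: active in a default build, removed when NDEBUG is defined.
__global__ void check_kernel(const float* x, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    assert(x[i] >= 0.0f);  // compiled out under -DNDEBUG
  }
}

int main() {
  const int n = 4;
  float host[4] = {1.f, 2.f, 3.f, -4.f};  // the negative value trips the assert
  float* dev = nullptr;
  cudaMalloc(&dev, n * sizeof(float));
  cudaMemcpy(dev, host, n * sizeof(float), cudaMemcpyHostToDevice);
  check_kernel<<<1, n>>>(dev, n);
  cudaDeviceSynchronize();
  cudaFree(dev);
#ifdef NDEBUG
  printf("built with -DNDEBUG: device asserts disabled\n");
#else
  printf("debug build: device asserts active\n");
#endif
  return 0;
}

Building with nvcc demo.cu versus nvcc -DNDEBUG demo.cu shows the difference: only the first build reports the failed device assert.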
23 changes: 18 additions & 5 deletions sgl-kernel/src/sgl-kernel/csrc/fp8_gemm_kernel.cu
@@ -4,6 +4,8 @@

#pragma once

#include <cudaTypedefs.h>

#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>

@@ -33,6 +35,7 @@

using namespace cute;

#if defined CUDA_VERSION && CUDA_VERSION >= 12040
template <typename ElementType, typename OutElementType, typename AccumElementType, typename CtaShape,
typename WarpShape, int Stages, bool WithBias,
typename FP8MathOperator = cutlass::arch::OpMultiplyAdd,
@@ -301,7 +304,9 @@ void sm89_dispatch_shape(torch::Tensor& out, const torch::Tensor& a, const torch
}
}
}
#endif

#if defined CUDA_VERSION && CUDA_VERSION >= 12000
template <typename ElementType, typename OutElementType, typename AccumElementType, typename CTAShape,
typename ClusterShape, typename MainloopScheduleType, typename EpilogueScheduleType,
typename TileSchedulerType = void, bool WithBias = false>
@@ -398,7 +403,7 @@ struct DeviceGemmFp8RowwiseSm90
using FastPongSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;

using SlowAccum = DefaultSchedule;
using FastAccum = FastDefaultSchedule;
using FastAccum = FastPongSchedule; // Default apply Pingpong
using MainLoopSchedule = cute::conditional_t<FAST_ACCUM, FastAccum, SlowAccum>;

using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<ArchTag, OperatorClass, ElementA,
@@ -535,6 +540,7 @@ void sm90_dispatch_shape(torch::Tensor& out, const torch::Tensor& a, const torch
return sm90_dispatch_bias<OutType, Shape<_128, _128, _128>, Shape<_2, _1, _1>>(out, a, b, scales_a, scales_b, bias);
}
}
#endif

torch::Tensor fp8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat_b, const torch::Tensor& scales_a,
const torch::Tensor& scales_b, const torch::Dtype& out_dtype,
@@ -571,21 +577,28 @@ torch::Tensor fp8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat

auto sm_version = getSMVersion();

#if defined CUDA_VERSION && CUDA_VERSION >= 12000
if (sm_version >= 90) {
if (out_dtype == torch::kBFloat16) {
sm90_dispatch_shape<cutlass::bfloat16_t>(out, mat_a, mat_b, scales_a, scales_b, bias);
} else {
sm90_dispatch_shape<cutlass::half_t>(out, mat_a, mat_b, scales_a, scales_b, bias);
}
} else if (sm_version == 89) {
return out;
}
#endif

#if defined CUDA_VERSION && CUDA_VERSION >= 12040
if (sm_version == 89) {
if (out_dtype == torch::kBFloat16) {
sm89_dispatch_shape<cutlass::bfloat16_t>(out, mat_a, mat_b, scales_a, scales_b, bias);
} else {
sm89_dispatch_shape<cutlass::half_t>(out, mat_a, mat_b, scales_a, scales_b, bias);
}
} else {
return out;
}
#endif

TORCH_CHECK_NOT_IMPLEMENTED(false, "No implemented fp8_scaled_mm for current compute capability: ", sm_version);
}

return out;
}
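The #if defined CUDA_VERSION && CUDA_VERSION >= ... guards compile the SM90 (Hopper) FP8 path only against CUDA 12.0+ and the SM89 (Ada) path only against CUDA 12.4+; the new cudaTypedefs.h include appears to be what makes CUDA_VERSION visible in this translation unit. Toolkits that lack a path fall through to TORCH_CHECK_NOT_IMPLEMENTED instead of failing to compile. A simplified sketch of this dispatch skeleton follows; run_sm90 and run_sm89 are hypothetical stand-ins for the real sm90_dispatch_shape / sm89_dispatch_shape calls.

// Simplified sketch of the guarded dispatch in fp8_scaled_mm (illustrative only).
#include <cudaTypedefs.h>   // pulls in cuda.h, which defines CUDA_VERSION
#include <torch/all.h>

torch::Tensor fp8_dispatch_sketch(torch::Tensor out, int sm_version) {
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
  if (sm_version >= 90) {
    // run_sm90(out);        // Hopper kernels, compiled only with CUDA 12.0+
    return out;
  }
#endif

#if defined CUDA_VERSION && CUDA_VERSION >= 12040
  if (sm_version == 89) {
    // run_sm89(out);        // Ada (SM89) kernels, compiled only with CUDA 12.4+
    return out;
  }
#endif

  // Reached when the toolkit is too old for this GPU or the SM is unsupported.
  TORCH_CHECK_NOT_IMPLEMENTED(false, "No implemented fp8_scaled_mm for current compute capability: ", sm_version);
  return out;  // unreachable; keeps -Wreturn-type quiet
}

With this layout, a CUDA 12.1 build still compiles and serves SM90 GPUs, while SM89 requests raise the not-implemented error at runtime rather than breaking the build.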
