
Commit e4eff3a

Allow builds on less than sm75, raise runtime failure (#1999)
stack-info: PR: #1999, branch: drisspg/stack/45
1 parent 8776dd3 commit e4eff3a

3 files changed: +65 -19 lines

torchao/csrc/cuda/fp6_llm/fp6_linear.cu

+38 -14

@@ -21,6 +21,7 @@
 //
 // MODIFICATION NOTE (2024-09-25): added SM75 support (https://github.com/pytorch/ao/pull/942):
 // - Modified the TilingConfig parameters for SM75 to deal with smaller shared memory
+// - Added proper architecture check at both host and device level
 //


@@ -98,7 +99,24 @@ void fpx_linear_kernel(cudaStream_t stream,
     static_assert(std::is_same<InputDataType, half>::value || std::is_same<InputDataType, __nv_bfloat16>::value, "Type must be 'half' or '__nv_bfloat16'");
     assert(M_Global % 256 == 0);
     assert(K_Global % 64 == 0);
-    assert(N_Global>0);
+    assert(N_Global > 0);
+
+    // Check GPU Compute Capability before proceeding
+    int device, major, minor;
+    CHECK_CUDA(cudaGetDevice(&device));
+    CHECK_CUDA(cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device));
+    CHECK_CUDA(cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device));
+
+    // Early exit with error for unsupported architectures
+    if ((major < 7) || (major == 7 && minor < 5)) {
+        TORCH_CHECK(false, "Quant-LLM Error: This kernel requires GPU with SM75 (Turing) or higher architecture. "
+                           "Your current device has SM", major, minor, " which is not supported.");
+    }
+
+    const bool is_sm75_gpu = (major == 7) && (minor == 5);
+    if (is_sm75_gpu && std::is_same<InputDataType, __nv_bfloat16>::value) {
+        TORCH_CHECK(false, "Quant-LLM Error: BFloat16 inputs are not supported on SM75 (Turing) GPUs.");
+    }

     // Work around to support more N shapes:
     size_t N_PowerOf2;
@@ -109,17 +127,6 @@ void fpx_linear_kernel(cudaStream_t stream,
     if(N_Global>64 && N_Global<=128) N_PowerOf2 = 128;
     if(N_Global>128) N_PowerOf2 = ((N_Global-1)/128+1) * 128;

-    // Check GPU Compute Capability
-    int device, major, minor;
-    CHECK_CUDA(cudaGetDevice(&device));
-    CHECK_CUDA(cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device));
-    CHECK_CUDA(cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device));
-    const bool is_sm75_gpu = (major == 7) && (minor == 5);
-    if (is_sm75_gpu && std::is_same<InputDataType, __nv_bfloat16>::value)
-        TORCH_CHECK(false, "Bfloat16 inputs are not supported for SM75");
-    if ((major < 7) || (major == 7 && minor < 5))
-        TORCH_CHECK(false, "FP6LLM_API Error: FP6LLM requires GPU with SM75 or higher!\n");
-
     if (is_sm75_gpu && (N_PowerOf2 == 64 || N_PowerOf2 == 128 || N_PowerOf2 % 128 == 0)) {
         // For SM75 and N >= 64, we use a different TilingConfig to deal with smaller shared memory.
         if (Split_K == 1) {
@@ -136,7 +143,7 @@ void fpx_linear_kernel(cudaStream_t stream,
             case 64: Kernel_Ex<TilingConfig<4, 1, 8>, InputDataType, InputDataType, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break;
             case 128: Kernel_Ex<TilingConfig<4, 1, 8>, InputDataType, InputDataType, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break;
             default: if (N_PowerOf2 % 128 != 0) {
-                TORCH_CHECK(false, "FP6LLM_API Error: Unsupported N dimension ", N_PowerOf2);
+                TORCH_CHECK(false, "Quant-LLM Error: Unsupported N dimension ", N_PowerOf2);
             }
             Kernel_Ex<TilingConfig<4, 1, 8>, InputDataType, InputDataType, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break;
         }
@@ -149,7 +156,7 @@ void fpx_linear_kernel(cudaStream_t stream,
             case 64: Kernel_Ex<TilingConfig<4, 1, 8>, InputDataType, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break;
             case 128: Kernel_Ex<TilingConfig<4, 1, 8>, InputDataType, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break;
             default: if (N_PowerOf2 % 128 != 0) {
-                TORCH_CHECK(false, "FP6LLM_API Error: Unsupported N dimension ", N_PowerOf2);
+                TORCH_CHECK(false, "Quant-LLM Error: Unsupported N dimension ", N_PowerOf2);
             }
             Kernel_Ex<TilingConfig<4, 1, 8>, InputDataType, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break;
         }
@@ -210,6 +217,23 @@ torch::Tensor fp_eXmY_linear_forward_cuda(
     torch::Tensor _scales,
     int64_t splitK=1)
 {
+    // Check GPU Compute Capability before proceeding
+    int device, major, minor;
+    CHECK_CUDA(cudaGetDevice(&device));
+    CHECK_CUDA(cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device));
+    CHECK_CUDA(cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device));
+
+    // Early exit with error for unsupported architectures
+    if ((major < 7) || (major == 7 && minor < 5)) {
+        TORCH_CHECK(false, "Quant-LLM Error: This kernel requires GPU with SM75 (Turing) or higher architecture. "
+                           "Your current device has SM", major, minor, " which is not supported.");
+    }
+
+    const bool is_sm75_gpu = (major == 7) && (minor == 5);
+    if (is_sm75_gpu && _in_feats.scalar_type() == at::ScalarType::BFloat16) {
+        TORCH_CHECK(false, "Quant-LLM Error: BFloat16 inputs are not supported on SM75 (Turing) GPUs.");
+    }
+
     const int64_t NBITS = 1 + EXPONENT + MANTISSA;
     int num_in_feats = _in_feats.size(0);
     int num_in_channels = _in_feats.size(1);
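
For reference, the host-side guard added in this file reduces to a single compute-capability query before any kernel work. Below is a minimal, self-contained sketch of that pattern; it uses std::runtime_error in place of TORCH_CHECK so it compiles outside of PyTorch, and the helper name require_sm75_or_newer is made up for illustration.

#include <cuda_runtime.h>
#include <stdexcept>
#include <string>

// Hypothetical standalone helper mirroring the check added to fpx_linear_kernel:
// query the current device's compute capability and fail fast below SM75 (Turing).
static void require_sm75_or_newer() {
    int device = 0, major = 0, minor = 0;
    if (cudaGetDevice(&device) != cudaSuccess) {
        throw std::runtime_error("cudaGetDevice failed");
    }
    cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device);
    cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device);
    if (major < 7 || (major == 7 && minor < 5)) {
        throw std::runtime_error("Quant-LLM requires SM75 or newer, found SM" +
                                 std::to_string(major) + std::to_string(minor));
    }
}

Running a check like this (or the equivalent TORCH_CHECK block above) at the top of every host entry point is what turns an unsupported-architecture launch into a clear runtime error instead of a failure inside the kernel.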

torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh

+14 -5

@@ -51,17 +51,14 @@
  * B: col major, FP16
  * C: col major, FP16
  */
-template<typename TilingConfig, typename InputDataType, typename OutputDataType, int EXPONENT, int MANTISSA>
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
+template<typename TilingConfig, typename InputDataType, typename OutputDataType, int EXPONENT, int MANTISSA>
 __global__ void QUANT_GEMM_Kernel(const uint4* Weight, const half* Scales,
                                   const half *B,
                                   OutputDataType* C,
                                   const size_t M_Global, const size_t N_Global, const size_t K_Global,
                                   int Split_K)
 {
-  #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
-  static_assert(false, "Quant-LLM kernel: At least Turing generation (sm75) is required.");
-  // __trap(); // fails at runtime instead of compile time
-  #endif
   #ifdef DEBUG_MODE
     assert(K_Global%TilingConfig::TILE_K==0);
     assert(M_Global%TilingConfig::TILE_M==0);
@@ -233,3 +230,15 @@ __global__ void QUANT_GEMM_Kernel(const uint4* Weight, const half* Scales,
         }
     }
 }
+#else
+// Stub implementation for older architectures
+template<typename TilingConfig, typename InputDataType, typename OutputDataType, int EXPONENT, int MANTISSA>
+__global__ void QUANT_GEMM_Kernel(const uint4* Weight, const half* Scales,
+                                  const half *B,
+                                  OutputDataType* C,
+                                  const size_t M_Global, const size_t N_Global, const size_t K_Global,
+                                  int Split_K)
+{
+  // NOOP, should never actually be called
+}
+#endif
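
The kernel-side change above follows a guard-plus-stub pattern: the real kernel body is only compiled when __CUDA_ARCH__ is defined and at least 750, while every other compilation pass (older device architectures, and the host pass where __CUDA_ARCH__ is undefined) gets an empty stub so the translation unit still compiles and links. A minimal sketch of the idea, using a placeholder kernel rather than the real QUANT_GEMM_Kernel:

#include <cuda_runtime.h>

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
// Real body: only instantiated for SM75+ device compilation passes.
__global__ void demo_kernel(float* out) {
    out[threadIdx.x] = 1.0f;
}
#else
// No-op stub: keeps the symbol defined for pre-SM75 targets and the host pass;
// the host-side compute-capability check prevents it from ever being launched.
__global__ void demo_kernel(float* out) {
    (void)out;
}
#endif

The removed static_assert(false, ...) made any build whose architecture list included something below sm75 fail at compile time; the stub keeps those builds working and defers the failure to the runtime checks shown earlier.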

torchao/ops.py

+13

@@ -71,6 +71,13 @@ def decorator(func):
     return decorator


+@functools.lru_cache
+def cached_compute_capability():
+    device_props = torch.cuda.get_device_properties(torch.cuda.current_device())
+    compute_capability = device_props.major * 10 + device_props.minor
+    return compute_capability
+
+
 def quant_llm_linear(
     EXPONENT: int,
     MANTISSA: int,
@@ -93,6 +100,12 @@ def quant_llm_linear(
     Returns
         output of linear layer
     """
+    # Check if we're on a supported architecture (sm7.5 or higher)
+    compute_capability = cached_compute_capability()
+    torch._check(
+        compute_capability >= 75,
+        lambda: f"quant_llm_linear requires sm7.5+ GPU architecture, but current device has sm{compute_capability}",
+    )
     return torch.ops.torchao.quant_llm_linear.default(
         EXPONENT, MANTISSA, _in_feats, _weights, _scales, splitK
     )
