 //
 // MODIFICATION NOTE (2024-09-25): added SM75 support (https://github.com/pytorch/ao/pull/942):
 // - Modified the TilingConfig parameters for SM75 to deal with smaller shared memory
+// - Added proper architecture check at both host and device level
 //
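The "architecture check at both host and device level" mentioned in the note above refers, on the device side, to the usual `__CUDA_ARCH__` guard: nvcc defines that macro only while compiling device code (750 for SM75), so it is the standard way to gate instructions per architecture. The sketch below only illustrates that general pattern; the kernel name and fallback body are placeholders, not the PR's actual device code.

// Illustrative sketch of a device-level architecture guard (not the PR's actual kernel).
// __CUDA_ARCH__ is defined only during device compilation, so the check must live
// inside __global__/__device__ code rather than in host code.
__global__ void arch_gated_kernel(float* out) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750)
    // SM75+ path: Turing-era features the kernel relies on are available.
    out[threadIdx.x] = 1.0f;
#else
    // Pre-Turing path: compile a stub so builds targeting older architectures still link;
    // the host-side TORCH_CHECK below prevents this path from ever being launched.
    if (threadIdx.x == 0) out[0] = 0.0f;
#endif
}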
@@ -98,7 +99,24 @@ void fpx_linear_kernel(cudaStream_t stream,
     static_assert(std::is_same<InputDataType, half>::value || std::is_same<InputDataType, __nv_bfloat16>::value, "Type must be 'half' or '__nv_bfloat16'");
     assert(M_Global % 256 == 0);
     assert(K_Global % 64 == 0);
-    assert(N_Global>0);
+    assert(N_Global > 0);
+
+    // Check GPU Compute Capability before proceeding
+    int device, major, minor;
+    CHECK_CUDA(cudaGetDevice(&device));
+    CHECK_CUDA(cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device));
+    CHECK_CUDA(cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device));
+
+    // Early exit with error for unsupported architectures
+    if ((major < 7) || (major == 7 && minor < 5)) {
+        TORCH_CHECK(false, "Quant-LLM Error: This kernel requires GPU with SM75 (Turing) or higher architecture. "
+                           "Your current device has SM", major, minor, " which is not supported.");
+    }
+
+    const bool is_sm75_gpu = (major == 7) && (minor == 5);
+    if (is_sm75_gpu && std::is_same<InputDataType, __nv_bfloat16>::value) {
+        TORCH_CHECK(false, "Quant-LLM Error: BFloat16 inputs are not supported on SM75 (Turing) GPUs.");
+    }
 
     // Work around to support more N shapes:
     size_t N_PowerOf2;
@@ -109,17 +127,6 @@ void fpx_linear_kernel(cudaStream_t stream,
     if (N_Global>64 && N_Global<=128) N_PowerOf2 = 128;
     if (N_Global>128) N_PowerOf2 = ((N_Global-1)/128+1) * 128;
 
-    // Check GPU Compute Capability
-    int device, major, minor;
-    CHECK_CUDA(cudaGetDevice(&device));
-    CHECK_CUDA(cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device));
-    CHECK_CUDA(cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device));
-    const bool is_sm75_gpu = (major == 7) && (minor == 5);
-    if (is_sm75_gpu && std::is_same<InputDataType, __nv_bfloat16>::value)
-        TORCH_CHECK(false, "Bfloat16 inputs are not supported for SM75");
-    if ((major < 7) || (major == 7 && minor < 5))
-        TORCH_CHECK(false, "FP6LLM_API Error: FP6LLM requires GPU with SM75 or higher!\n");
-
     if (is_sm75_gpu && (N_PowerOf2 == 64 || N_PowerOf2 == 128 || N_PowerOf2 % 128 == 0)) {
         // For SM75 and N >= 64, we use a different TilingConfig to deal with smaller shared memory.
         if (Split_K == 1) {
@@ -136,7 +143,7 @@ void fpx_linear_kernel(cudaStream_t stream,
             case 64:  Kernel_Ex<TilingConfig<4, 1, 8>, InputDataType, InputDataType, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break;
             case 128: Kernel_Ex<TilingConfig<4, 1, 8>, InputDataType, InputDataType, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break;
             default:  if (N_PowerOf2 % 128 != 0) {
-                          TORCH_CHECK(false, "FP6LLM_API Error: Unsupported N dimension ", N_PowerOf2);
+                          TORCH_CHECK(false, "Quant-LLM Error: Unsupported N dimension ", N_PowerOf2);
                       }
                       Kernel_Ex<TilingConfig<4, 1, 8>, InputDataType, InputDataType, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break;
         }
@@ -149,7 +156,7 @@ void fpx_linear_kernel(cudaStream_t stream,
             case 64:  Kernel_Ex<TilingConfig<4, 1, 8>, InputDataType, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break;
             case 128: Kernel_Ex<TilingConfig<4, 1, 8>, InputDataType, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break;
             default:  if (N_PowerOf2 % 128 != 0) {
-                          TORCH_CHECK(false, "FP6LLM_API Error: Unsupported N dimension ", N_PowerOf2);
+                          TORCH_CHECK(false, "Quant-LLM Error: Unsupported N dimension ", N_PowerOf2);
                       }
                       Kernel_Ex<TilingConfig<4, 1, 8>, InputDataType, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break;
         }
@@ -210,6 +217,23 @@ torch::Tensor fp_eXmY_linear_forward_cuda(
     torch::Tensor _scales,
     int64_t splitK=1)
 {
+    // Check GPU Compute Capability before proceeding
+    int device, major, minor;
+    CHECK_CUDA(cudaGetDevice(&device));
+    CHECK_CUDA(cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device));
+    CHECK_CUDA(cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device));
+
+    // Early exit with error for unsupported architectures
+    if ((major < 7) || (major == 7 && minor < 5)) {
+        TORCH_CHECK(false, "Quant-LLM Error: This kernel requires GPU with SM75 (Turing) or higher architecture. "
+                           "Your current device has SM", major, minor, " which is not supported.");
+    }
+
+    const bool is_sm75_gpu = (major == 7) && (minor == 5);
+    if (is_sm75_gpu && _in_feats.scalar_type() == at::ScalarType::BFloat16) {
+        TORCH_CHECK(false, "Quant-LLM Error: BFloat16 inputs are not supported on SM75 (Turing) GPUs.");
+    }
+
     const int64_t NBITS = 1 + EXPONENT + MANTISSA;
     int num_in_feats = _in_feats.size(0);
     int num_in_channels = _in_feats.size(1);
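For readers who want to exercise the host-side gate outside of PyTorch, here is a self-contained sketch of the same compute-capability query. It assumes nothing from the repository: the CHECK_CUDA and TORCH_CHECK macros are replaced with plain CUDA runtime error handling, and the function and message names are made up for illustration.

#include <cstdio>
#include <cuda_runtime.h>

// Mirrors the gate added in fpx_linear_kernel / fp_eXmY_linear_forward_cuda:
// require SM75 (Turing) or newer, and record whether the device is exactly SM75,
// since that case additionally rules out bfloat16 inputs.
static bool supports_sm75_or_newer(bool* is_sm75) {
    int device = 0, major = 0, minor = 0;
    if (cudaGetDevice(&device) != cudaSuccess) return false;
    if (cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device) != cudaSuccess) return false;
    if (cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device) != cudaSuccess) return false;
    *is_sm75 = (major == 7 && minor == 5);
    return (major > 7) || (major == 7 && minor >= 5);
}

int main() {
    bool is_sm75 = false;
    if (!supports_sm75_or_newer(&is_sm75)) {
        std::printf("This kernel requires SM75 (Turing) or higher.\n");
        return 1;
    }
    std::printf("Supported GPU%s.\n", is_sm75 ? " (SM75: fp16 inputs only, no bf16)" : "");
    return 0;
}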