Commit f1c7574

Tests: improve CUDA support detection (#985)
* implicitly skip any test that implicitly uses CUDA on a non-CUDA box
* add a `requires_cuda` fixture
1 parent 53f8af8 commit f1c7574

File tree

8 files changed: +26 -14 lines changed

tests/conftest.py

Lines changed: 19 additions & 0 deletions

@@ -0,0 +1,19 @@
+import pytest
+import torch
+
+
+def pytest_runtest_call(item):
+    try:
+        item.runtest()
+    except AssertionError as ae:
+        if str(ae) == "Torch not compiled with CUDA enabled":
+            pytest.skip("Torch not compiled with CUDA enabled")
+        raise
+
+
+@pytest.fixture(scope="session")
+def requires_cuda() -> bool:
+    cuda_available = torch.cuda.is_available()
+    if not cuda_available:
+        pytest.skip("CUDA is required")
+    return cuda_available
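
For context, a minimal usage sketch (not part of the commit): a test can either request CUDA explicitly via the new `requires_cuda` fixture, or rely on the `pytest_runtest_call` hook above to convert PyTorch's "Torch not compiled with CUDA enabled" assertion into a skip on a CPU-only build. The test names and tensor shapes below are hypothetical.

import torch


def test_quantize_roundtrip(requires_cuda):
    # Explicit opt-in: the session-scoped fixture calls pytest.skip("CUDA is required")
    # before the test body runs when torch.cuda.is_available() is False.
    A = torch.randn(16, 16, device="cuda")
    assert A.is_cuda


def test_implicit_cuda_use():
    # No fixture needed: on a CPU-only torch build, allocating a CUDA tensor raises
    # AssertionError("Torch not compiled with CUDA enabled"), which the
    # pytest_runtest_call hook turns into a skip instead of a failure.
    B = torch.zeros(8, device="cuda")
    assert B.sum() == 0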

tests/test_autograd.py

Lines changed: 0 additions & 4 deletions
@@ -40,7 +40,6 @@
     ids=names,
 )
 def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
-    if not torch.cuda.is_available(): pytest.skip('No GPU found.')
     if dim2 > 0:
         dim2 = dim2 - (dim2 % 16)
     dim3 = dim3 - (dim3 % 16)
@@ -307,7 +306,6 @@ def test_matmullt(
     has_fp16_weights,
     has_bias
 ):
-    if not torch.cuda.is_available(): pytest.skip('No GPU found.')
     dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
     dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
     outlier_dim = torch.randint(0, dimA[1], size=(dimA[1] // 8,), device="cuda")
@@ -461,7 +459,6 @@ def test_matmullt(
 values = list(product(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias, compress_statistics, quant_type))
 str_values = list(product(dim1, dim2, dim3, dim4, str_funcs, dtype, req_grad_str, str_transpose, has_bias, compress_statistics, quant_type))
 names = ["dim1_{}_dim2_{}_dim3_{}_dim4_{}_func_{}_dtype_{}_requires_grad_{}_transpose_{}_has_bias_{}_compress_statistics_{}_quant_type_{}".format(*vals) for vals in str_values]
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
 @pytest.mark.parametrize( "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias, compress_statistics, quant_type", values, ids=names)
 def test_matmul_4bit( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias, compress_statistics, quant_type):
     dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
@@ -551,7 +548,6 @@ def test_matmul_4bit( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose,
 values = list(product(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose))
 str_values = list(product(dim1, dim2, dim3, dim4, str_funcs, dtype, req_grad_str, str_transpose))
 names = ["dim1_{}_dim2_{}_dim3_{}_dim4_{}_func_{}_dtype_{}_requires_grad_{}_transpose_{}".format(*vals) for vals in str_values]
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
 @pytest.mark.parametrize( "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose", values, ids=names)
 def test_matmul_fp8( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
     dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)

tests/test_cuda_setup_evaluator.py

Lines changed: 2 additions & 2 deletions
@@ -5,12 +5,12 @@

 # hardcoded test. Not good, but a sanity check for now
 # TODO: improve this
-def test_manual_override():
+def test_manual_override(requires_cuda):
     manual_cuda_path = str(Path('/mmfs1/home/dettmers/data/local/cuda-12.2'))

     pytorch_version = torch.version.cuda.replace('.', '')

-    assert pytorch_version != 122
+    assert pytorch_version != 122  # TODO: this will never be true...

     os.environ['CUDA_HOME']='{manual_cuda_path}'
     os.environ['BNB_CUDA_VERSION']='122'

tests/test_functional.py

Lines changed: 4 additions & 3 deletions
@@ -617,7 +617,10 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans
         return
     if dtype == torch.int32 and out_order != "col32":
         return
-    func = F.get_transform_func(dtype, orderA, orderOut, transpose)
+    try:
+        func = F.get_transform_func(dtype, orderA, orderOut, transpose)
+    except ValueError as ve:
+        pytest.skip(str(ve))  # skip if not supported

     if dims == 2:
         A = torch.randint(-128, 127, size=(dim1, dim2), device="cuda").to(dtype)
@@ -2278,7 +2281,6 @@ def test_fp4_quant(dtype):
     assert relerr.item() < 0.28


-@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
 @pytest.mark.parametrize("quant_type", ['fp4', 'nf4'])
 def test_4bit_compressed_stats(quant_type):
     for blocksize in [128, 64]:
@@ -2317,7 +2319,6 @@ def test_4bit_compressed_stats(quant_type):



-@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
 #@pytest.mark.parametrize("quant_type", ['fp4', 'nf4'])
 @pytest.mark.parametrize("quant_type", ['nf4'])
 def test_bench_4bit_dequant(quant_type):

tests/test_generation.py

Lines changed: 1 addition & 1 deletion
@@ -79,7 +79,7 @@ def model_and_tokenizer(request):
 @pytest.mark.parametrize("DQ", [True, False], ids=['DQ_True', 'DQ_False'])
 @pytest.mark.parametrize("inference_kernel", [True, False], ids=['inference_kernel_True', 'inference_kernel_False'])
 #@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=['fp16', 'bf16', 'fp32'])
-def test_pi(model_and_tokenizer, inference_kernel, DQ):
+def test_pi(requires_cuda, model_and_tokenizer, inference_kernel, DQ):
     print('')
     dtype = torch.float16

tests/test_linear4bit.py

Lines changed: 0 additions & 1 deletion
@@ -15,7 +15,6 @@
     'float32': torch.float32
 }

-@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
 @pytest.mark.parametrize(
     "quant_type, compress_statistics, bias, quant_storage",
     list(product(["nf4", "fp4"], [False, True], [False, True], ['uint8', 'float16', 'bfloat16', 'float32'])),

tests/test_linear8bitlt.py

Lines changed: 0 additions & 2 deletions
@@ -33,7 +33,6 @@ def test_layout_exact_match():
     assert torch.all(torch.eq(restored_x, x))


-@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
 def test_linear_no_igemmlt():
     linear = torch.nn.Linear(1024, 3072)
     x = torch.randn(3, 1024, dtype=torch.half)
@@ -68,7 +67,6 @@ def test_linear_no_igemmlt():
     assert linear_custom.state.CxB is None


-@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
 @pytest.mark.parametrize("has_fp16_weights, serialize_before_forward, deserialize_before_cuda, force_no_igemmlt",
                          list(product([False, True], [False, True], [False, True], [False, True])))
 def test_linear_serialization(has_fp16_weights, serialize_before_forward, deserialize_before_cuda, force_no_igemmlt):

tests/test_modules.py

Lines changed: 0 additions & 1 deletion
@@ -520,7 +520,6 @@ def test_linear_kbit_fp32_bias(module):
 modules.append(lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compute_dtype=torch.float16))
 modules.append(lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compute_dtype=torch.bfloat16))
 names = ['Int8Lt', '4bit', 'FP4', 'NF4', 'FP4+C', 'NF4+C', 'NF4+fp32', 'NF4+fp16', 'NF4+bf16']
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
 @pytest.mark.parametrize("module", modules, ids=names)
 def test_kbit_backprop(module):
     b = 17

0 commit comments

Comments
 (0)