From 675e35bcd43d11d876d61be7fa2b37e751b922e1 Mon Sep 17 00:00:00 2001
From: Fanli Lin
Date: Thu, 23 Jan 2025 22:23:38 +0800
Subject: [PATCH] [tests] enable more bnb tests on XPU (#3350)

* enable bnb tests

* bug fix

* enable more bnb tests on xpu

* fix quality issue

* further fix quality

* fix style
---
 tests/test_accelerator.py  | 11 +++++++----
 tests/test_big_modeling.py | 15 ++++++++-------
 2 files changed, 15 insertions(+), 11 deletions(-)

diff --git a/tests/test_accelerator.py b/tests/test_accelerator.py
index 0171ad072a3..b3b1744f343 100644
--- a/tests/test_accelerator.py
+++ b/tests/test_accelerator.py
@@ -30,8 +30,9 @@
 from accelerate.state import GradientState, PartialState
 from accelerate.test_utils import (
     require_bnb,
+    require_cuda_or_xpu,
     require_huggingface_suite,
-    require_multi_gpu,
+    require_multi_device,
     require_non_cpu,
     require_transformer_engine,
     slow,
@@ -452,7 +453,7 @@ def test_is_accelerator_prepared(self):
             getattr(valid_dl, "_is_accelerate_prepared", False) is True
         ), "Valid Dataloader is missing `_is_accelerator_prepared` or is set to `False`"
 
-    @require_cuda
+    @require_cuda_or_xpu
     @slow
     @require_bnb
     def test_accelerator_bnb(self):
@@ -498,7 +499,7 @@ def test_accelerator_bnb_cpu_error(self):
     @require_non_torch_xla
     @slow
     @require_bnb
-    @require_multi_gpu
+    @require_multi_device
     def test_accelerator_bnb_multi_device(self):
         """Tests that the accelerator can be used with the BNB library."""
         from transformers import AutoModelForCausalLM
@@ -507,6 +508,8 @@ def test_accelerator_bnb_multi_device(self):
             PartialState._shared_state = {"distributed_type": DistributedType.MULTI_GPU}
         elif torch_device == "npu":
             PartialState._shared_state = {"distributed_type": DistributedType.MULTI_NPU}
+        elif torch_device == "xpu":
+            PartialState._shared_state = {"distributed_type": DistributedType.MULTI_XPU}
         else:
             raise ValueError(f"{torch_device} is not supported in test_accelerator_bnb_multi_device.")
 
@@ -534,7 +537,7 @@ def test_accelerator_bnb_multi_device(self):
     @require_non_torch_xla
     @slow
     @require_bnb
-    @require_multi_gpu
+    @require_multi_device
     def test_accelerator_bnb_multi_device_no_distributed(self):
         """Tests that the accelerator can be used with the BNB library."""
         from transformers import AutoModelForCausalLM
diff --git a/tests/test_big_modeling.py b/tests/test_big_modeling.py
index bb5bc5049ef..8595cdcc993 100644
--- a/tests/test_big_modeling.py
+++ b/tests/test_big_modeling.py
@@ -36,6 +36,7 @@
 from accelerate.test_utils import (
     require_bnb,
     require_cuda,
+    require_cuda_or_xpu,
     require_multi_device,
     require_multi_gpu,
     require_non_cpu,
@@ -877,7 +878,7 @@ def test_cpu_offload_with_hook(self):
     @require_non_torch_xla
     @slow
     @require_bnb
-    @require_multi_gpu
+    @require_multi_device
     def test_dispatch_model_bnb(self):
         """Tests that `dispatch_model` quantizes int8 layers"""
         from huggingface_hub import hf_hub_download
@@ -906,7 +907,7 @@ def test_dispatch_model_bnb(self):
         assert model.h[-1].self_attention.query_key_value.weight.dtype == torch.int8
         assert model.h[-1].self_attention.query_key_value.weight.device.index == 1
 
-    @require_cuda
+    @require_cuda_or_xpu
     @slow
     @require_bnb
     def test_dispatch_model_int8_simple(self):
@@ -946,7 +947,7 @@ def test_dispatch_model_int8_simple(self):
         model = load_checkpoint_and_dispatch(
             model,
             checkpoint=model_path,
-            device_map={"": torch.device("cuda:0")},
+            device_map={"": torch_device},
         )
 
         assert model.h[0].self_attention.query_key_value.weight.dtype == torch.int8
@@ -963,13 +964,13 @@ def test_dispatch_model_int8_simple(self):
         model = load_checkpoint_and_dispatch(
             model,
             checkpoint=model_path,
-            device_map={"": "cuda:0"},
+            device_map={"": torch_device},
         )
 
         assert model.h[0].self_attention.query_key_value.weight.dtype == torch.int8
         assert model.h[0].self_attention.query_key_value.weight.device.index == 0
 
-    @require_cuda
+    @require_cuda_or_xpu
     @slow
     @require_bnb
     def test_dipatch_model_fp4_simple(self):
@@ -1010,7 +1011,7 @@ def test_dipatch_model_fp4_simple(self):
         model = load_checkpoint_and_dispatch(
             model,
             checkpoint=model_path,
-            device_map={"": torch.device("cuda:0")},
+            device_map={"": torch_device},
         )
 
         assert model.h[0].self_attention.query_key_value.weight.dtype == torch.uint8
@@ -1027,7 +1028,7 @@ def test_dipatch_model_fp4_simple(self):
         model = load_checkpoint_and_dispatch(
             model,
             checkpoint=model_path,
-            device_map={"": "cuda:0"},
+            device_map={"": torch_device},
         )
 
         assert model.h[0].self_attention.query_key_value.weight.dtype == torch.uint8
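
The patch leans on two identifiers defined outside this diff: the `require_cuda_or_xpu` and `require_multi_device` skip decorators from `accelerate.test_utils`, and the `torch_device` string that replaces the hard-coded `"cuda:0"` device maps. Their implementations are not shown here; the following is a minimal illustrative sketch (not accelerate's actual code, which may differ in detail) of how such a backend-agnostic skip decorator can be built on `unittest.skipUnless`, assuming a PyTorch build recent enough to expose `torch.xpu`:

    # Illustrative sketch only -- the real decorator lives in accelerate.test_utils.
    import unittest

    import torch


    def require_cuda_or_xpu(test_case):
        """Skip the decorated test unless a CUDA or an XPU device is available."""
        cuda_available = torch.cuda.is_available()
        # torch.xpu only exists on recent PyTorch builds, so guard with hasattr.
        xpu_available = hasattr(torch, "xpu") and torch.xpu.is_available()
        return unittest.skipUnless(cuda_available or xpu_available, "test requires CUDA or XPU")(test_case)

Combined with a `torch_device` string that resolves to the detected backend ("cuda", "npu", or "xpu", as the comparisons in `test_accelerator_bnb_multi_device` above show), decorators of this shape let the same bnb tests run unchanged on NVIDIA GPUs and Intel XPUs, which is the substitution this patch performs throughout.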