[tests] enable more bnb tests on XPU (#3350)
* enable bnb tests

* bug fix

* enable more bnb tests on xpu

* fix quality issue

* further fix quality

* fix style
faaany authored Jan 23, 2025
1 parent 8f2d31c commit 675e35b
Showing 2 changed files with 15 additions and 11 deletions.
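
The core of the change is swapping CUDA-only test decorators for device-agnostic ones: `require_cuda` becomes `require_cuda_or_xpu`, and `require_multi_gpu` becomes `require_multi_device`. The real decorators ship in `accelerate.test_utils`; the sketch below only illustrates the idea, with the function name and skip message invented here:

import unittest

import torch


def require_cuda_or_xpu_sketch(test_case):
    # Skip the test unless a CUDA or an XPU accelerator is available.
    # `torch.xpu` only exists on recent PyTorch builds, hence the hasattr guard.
    cuda_ok = torch.cuda.is_available()
    xpu_ok = hasattr(torch, "xpu") and torch.xpu.is_available()
    return unittest.skipUnless(cuda_ok or xpu_ok, "test requires CUDA or XPU")(test_case)

Applied as `@require_cuda_or_xpu_sketch`, the wrapped test is skipped on machines with neither backend, which is what lets the bnb tests below run unmodified on Intel GPUs.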
11 changes: 7 additions & 4 deletions tests/test_accelerator.py
@@ -30,8 +30,9 @@
 from accelerate.state import GradientState, PartialState
 from accelerate.test_utils import (
     require_bnb,
+    require_cuda_or_xpu,
     require_huggingface_suite,
-    require_multi_gpu,
+    require_multi_device,
     require_non_cpu,
     require_transformer_engine,
     slow,
@@ -452,7 +453,7 @@ def test_is_accelerator_prepared(self):
             getattr(valid_dl, "_is_accelerate_prepared", False) is True
         ), "Valid Dataloader is missing `_is_accelerator_prepared` or is set to `False`"

-    @require_cuda
+    @require_cuda_or_xpu
     @slow
     @require_bnb
     def test_accelerator_bnb(self):
@@ -498,7 +499,7 @@ def test_accelerator_bnb_cpu_error(self):
     @require_non_torch_xla
     @slow
     @require_bnb
-    @require_multi_gpu
+    @require_multi_device
     def test_accelerator_bnb_multi_device(self):
         """Tests that the accelerator can be used with the BNB library."""
         from transformers import AutoModelForCausalLM
@@ -507,6 +508,8 @@ def test_accelerator_bnb_multi_device(self):
             PartialState._shared_state = {"distributed_type": DistributedType.MULTI_GPU}
         elif torch_device == "npu":
             PartialState._shared_state = {"distributed_type": DistributedType.MULTI_NPU}
+        elif torch_device == "xpu":
+            PartialState._shared_state = {"distributed_type": DistributedType.MULTI_XPU}
         else:
             raise ValueError(f"{torch_device} is not supported in test_accelerator_bnb_multi_device.")
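
The new `elif` branch keys the seeded `distributed_type` off the suite-wide `torch_device` string. How that string is resolved is internal to the test utilities; a hedged sketch of the idea (the `torch.npu` namespace assumes the out-of-tree `torch_npu` package):

import torch


def resolve_torch_device() -> str:
    # Illustrative only: return the first available accelerator backend.
    if torch.cuda.is_available():
        return "cuda"
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return "xpu"
    if hasattr(torch, "npu") and torch.npu.is_available():  # provided by torch_npu
        return "npu"
    return "cpu"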

@@ -534,7 +537,7 @@ def test_accelerator_bnb_multi_device(self):
     @require_non_torch_xla
     @slow
     @require_bnb
-    @require_multi_gpu
+    @require_multi_device
     def test_accelerator_bnb_multi_device_no_distributed(self):
         """Tests that the accelerator can be used with the BNB library."""
         from transformers import AutoModelForCausalLM
15 changes: 8 additions & 7 deletions tests/test_big_modeling.py
@@ -36,6 +36,7 @@
 from accelerate.test_utils import (
     require_bnb,
     require_cuda,
+    require_cuda_or_xpu,
     require_multi_device,
     require_multi_gpu,
     require_non_cpu,
@@ -877,7 +878,7 @@ def test_cpu_offload_with_hook(self):
     @require_non_torch_xla
     @slow
     @require_bnb
-    @require_multi_gpu
+    @require_multi_device
     def test_dispatch_model_bnb(self):
         """Tests that `dispatch_model` quantizes int8 layers"""
         from huggingface_hub import hf_hub_download
@@ -906,7 +907,7 @@ def test_dispatch_model_bnb(self):
         assert model.h[(-1)].self_attention.query_key_value.weight.dtype == torch.int8
         assert model.h[(-1)].self_attention.query_key_value.weight.device.index == 1

-    @require_cuda
+    @require_cuda_or_xpu
     @slow
     @require_bnb
     def test_dispatch_model_int8_simple(self):
@@ -946,7 +947,7 @@ def test_dispatch_model_int8_simple(self):
         model = load_checkpoint_and_dispatch(
             model,
             checkpoint=model_path,
-            device_map={"": torch.device("cuda:0")},
+            device_map={"": torch_device},
         )

         assert model.h[0].self_attention.query_key_value.weight.dtype == torch.int8
@@ -963,13 +964,13 @@ def test_dispatch_model_int8_simple(self):
         model = load_checkpoint_and_dispatch(
             model,
             checkpoint=model_path,
-            device_map={"": "cuda:0"},
+            device_map={"": torch_device},
         )

         assert model.h[0].self_attention.query_key_value.weight.dtype == torch.int8
         assert model.h[0].self_attention.query_key_value.weight.device.index == 0

-    @require_cuda
+    @require_cuda_or_xpu
     @slow
     @require_bnb
     def test_dipatch_model_fp4_simple(self):
@@ -1010,7 +1011,7 @@ def test_dipatch_model_fp4_simple(self):
         model = load_checkpoint_and_dispatch(
             model,
             checkpoint=model_path,
-            device_map={"": torch.device("cuda:0")},
+            device_map={"": torch_device},
         )

         assert model.h[0].self_attention.query_key_value.weight.dtype == torch.uint8
@@ -1027,7 +1028,7 @@ def test_dipatch_model_fp4_simple(self):
         model = load_checkpoint_and_dispatch(
             model,
             checkpoint=model_path,
-            device_map={"": "cuda:0"},
+            device_map={"": torch_device},
         )

         assert model.h[0].self_attention.query_key_value.weight.dtype == torch.uint8
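
Throughout this file the hard-coded `"cuda:0"` / `torch.device("cuda:0")` device maps become `{"": torch_device}`, so the same single-device placement works on CUDA and XPU alike. A hedged usage sketch of the pattern (the model id and checkpoint path are placeholders, not the test's actual fixtures):

import torch
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
from transformers import AutoConfig, AutoModelForCausalLM

# Illustrative device pick; the test suite derives torch_device in its own utilities.
torch_device = "xpu" if hasattr(torch, "xpu") and torch.xpu.is_available() else "cuda"

config = AutoConfig.from_pretrained("bigscience/bloom-560m")  # placeholder model
with init_empty_weights():
    model = AutoModelForCausalLM.from_config(config)  # no weights materialized yet

model = load_checkpoint_and_dispatch(
    model,
    checkpoint="/path/to/sharded/checkpoint",  # placeholder path
    device_map={"": torch_device},  # "" maps the whole model onto one device
)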
