
Commit

ensure all tensors are on the same device
Signed-off-by: Alex Chi <[email protected]>
skyzh committed Feb 1, 2025
1 parent cd180dd commit 40b1bb7
Showing 5 changed files with 21 additions and 15 deletions.
5 changes: 3 additions & 2 deletions build_and_run.sh
@@ -1,8 +1,9 @@
 #!/bin/bash

 set -e
+export VLLM_TARGET_DEVICE=metal
 export PYTORCH_ENABLE_MPS_FALLBACK=1
-export VLLM_TARGET_DEVICE=cpu
+# export VLLM_TARGET_DEVICE=cpu
 pip uninstall vllm
 python setup.py install
-vllm serve Qwen/Qwen2.5-0.5B-Instruct --dtype float16
+vllm serve Qwen/Qwen2.5-0.5B-Instruct
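
Setting VLLM_TARGET_DEVICE=metal makes setup.py build this fork's Metal backend instead of the plain CPU backend, while PYTORCH_ENABLE_MPS_FALLBACK=1 tells PyTorch to run any operator that has no MPS kernel on the CPU rather than raising an error. A quick way to confirm the MPS backend is actually usable before running the script (a separate check, not part of the commit):

python -c "import torch; print(torch.backends.mps.is_built(), torch.backends.mps.is_available())"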
3 changes: 2 additions & 1 deletion vllm/attention/backends/torch_sdpa.py
@@ -500,6 +500,7 @@ def forward(
                 PagedAttention.write_to_paged_cache(
                     key, value, key_cache, value_cache, updated_slot_mapping,
                     self.kv_cache_dtype, layer._k_scale, layer._v_scale)
+                torch.mps.synchronize()

             if attn_type != AttentionType.ENCODER:
                 # Decoder self-attention supports chunked prefill.
@@ -576,7 +577,7 @@ def forward(
                 layer._k_scale,
                 layer._v_scale,
             )
-
+            torch.mps.synchronize()
             # Reshape the output tensor.
             return output.view(-1, self.num_heads * self.head_size)

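MPS kernels are queued asynchronously, so the two torch.mps.synchronize() calls appear to act as explicit barriers: the paged KV-cache writes and the attention output are fully materialized before surrounding CPU-side code (or a CPU-fallback op) touches them, at the cost of some throughput. A minimal illustration of the same pattern, independent of vLLM:

import torch

x = torch.randn(1024, 1024, device="mps")
y = x @ x                   # enqueued asynchronously on the MPS command stream
torch.mps.synchronize()     # block until all queued MPS work has finished
print(y.sum().item())       # the result is now fully materialized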
2 changes: 2 additions & 0 deletions vllm/model_executor/custom_op.py
@@ -88,6 +88,8 @@ def dispatch_forward(self):
             return self.forward_xpu
         elif current_platform.is_out_of_tree():
             return self.forward_oot
+        elif current_platform.is_metal():
+            return self.forward_native
         else:
             return self.forward_cuda

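With no compiled Metal kernels available, CustomOp dispatch on this platform falls back to forward_native, the pure-PyTorch reference path that every custom op provides. A rough sketch of what such an op looks like, modeled on vLLM's SiluAndMul and simplified rather than copied from the source (CustomOp is the base class from the diff above):

import torch

class SiluAndMul(CustomOp):

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        # Reference implementation in plain PyTorch; runs on any device,
        # including MPS (with the CPU fallback covering missing ops).
        d = x.shape[-1] // 2
        return torch.nn.functional.silu(x[..., :d]) * x[..., d:]

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        # Would call the hand-written CUDA kernel; unavailable on Metal.
        ...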
9 changes: 0 additions & 9 deletions vllm/platforms/metal.py
@@ -74,14 +74,6 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
                 "Invalid environment variable VLLM_CPU_KVCACHE_SPACE"
                 f" {kv_cache_space}, expect a positive integer value.")

-        scheduler_config = vllm_config.scheduler_config
-        if ((scheduler_config.chunked_prefill_enabled
-             or cache_config.enable_prefix_caching)
-                and model_config.dtype == torch.half):
-            logger.warning("Chunked-prefill on the CPU backend only does not"
-                           " support fp16 for now, cast to bf16.")
-            model_config.dtype = torch.bfloat16  # TODO supported?
-
         parallel_config = vllm_config.parallel_config
         if parallel_config.worker_cls == "auto":
             parallel_config.worker_cls = "vllm.worker.cpu_worker.CPUWorker"
@@ -90,7 +82,6 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:

     @classmethod
     def is_pin_memory_available(cls) -> bool:
-        logger.warning("Pin memory is not supported on Metal.")
         return False

     @classmethod
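Dropping the warning in is_pin_memory_available() is cosmetic: it still returns False, and pinned host memory buys little on Apple silicon, where CPU and GPU share unified memory. Removing the fp16-to-bf16 cast does change behavior, though: a float16 model with chunked prefill or prefix caching is no longer silently recast to bfloat16, which has historically been less well supported on MPS than float16. A one-line probe for whether a given PyTorch/macOS combination accepts bfloat16 on MPS (it raises an error where unsupported):

python -c "import torch; print(torch.ones(1, device='mps', dtype=torch.bfloat16))"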
17 changes: 14 additions & 3 deletions vllm/worker/cpu_worker.py
@@ -40,6 +40,11 @@ def __init__(self, cache_config: CacheConfig, model_config: ModelConfig,
         self.model_config = model_config
         self.parallel_config = parallel_config

+        if device_config.device_type == "metal":
+            self.device = torch.device("mps")
+        else:
+            self.device = torch.device("cpu")
+
         self.head_size = model_config.get_head_size()
         self.num_layers = model_config.get_num_layers(parallel_config)
         self.num_heads = model_config.get_num_kv_heads(parallel_config)
@@ -77,7 +82,7 @@ def _allocate_kv_cache(
         kv_cache: List[torch.Tensor] = []
         for _ in range(self.num_layers):
             kv_cache.append(
-                torch.empty(kv_cache_shape, dtype=self.dtype, device="cpu"))
+                torch.empty(kv_cache_shape, dtype=self.dtype, device=self.device))
         return kv_cache

     def swap_in(self, src_to_dst: Dict[int, int]) -> None:
@@ -136,6 +141,13 @@ def __init__(
         self.rank = rank
         self.distributed_init_method = distributed_init_method

+        if self.device_config.device_type == "cpu":
+            self.device = torch.device("cpu")
+        elif self.device_config.device_type == "metal":
+            self.device = torch.device("mps")
+        else:
+            raise ValueError(f"Invalid device type: {self.device_config.device_type}")
+
         self.is_driver_worker = is_driver_worker
         if self.is_driver_worker:
             assert self.rank == 0, "The driver worker must have rank 0."
@@ -212,7 +224,6 @@ def init_device(self) -> None:
         ret = torch.ops._C_utils.init_cpu_threads_env(self.local_omp_cpuid)
         if ret:
             logger.info(ret)
-        self.device = torch.device("cpu")
         self.init_distributed_environment()
         # Set random seed.
         set_random_seed(self.model_config.seed)
@@ -350,7 +361,7 @@ def prepare_worker_input(
         virtual_engine: int = execute_model_req.virtual_engine
         num_seq_groups: int = len(execute_model_req.seq_group_metadata_list)
         blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
-                                      device="cpu",
+                                      device=self.device,
                                       dtype=torch.int64).view(-1, 2)
         assert len(execute_model_req.blocks_to_swap_in) == 0
         assert len(execute_model_req.blocks_to_swap_out) == 0
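The worker now resolves its torch.device once from the platform's device_type and uses it for KV-cache allocation and for the blocks_to_copy tensor in prepare_worker_input, which is what actually keeps all tensors on one device. The same mapping appears in both CPUCacheEngine and CPUWorker; distilled into a hypothetical standalone helper (not part of the commit), it is simply:

import torch

def resolve_device(device_type: str) -> torch.device:
    # "metal" is this fork's platform name for the Apple GPU, which PyTorch
    # exposes as the "mps" device; plain CPU stays on torch.device("cpu").
    if device_type == "metal":
        return torch.device("mps")
    if device_type == "cpu":
        return torch.device("cpu")
    raise ValueError(f"Invalid device type: {device_type}")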
