From 40b1bb768c15d3347dbc87796b1e034051e2d811 Mon Sep 17 00:00:00 2001
From: Alex Chi <iskyzh@gmail.com>
Date: Sat, 1 Feb 2025 00:11:38 -0500
Subject: [PATCH] ensure all tensors are on the same device

When building with VLLM_TARGET_DEVICE=metal, keep every tensor on the
same device: allocate the KV cache and the block-copy metadata on the
worker's device ("mps") instead of hard-coding "cpu", dispatch custom
ops to forward_native on Metal, and synchronize MPS after the
paged-attention cache writes and attention computation. Also remove the
chunked-prefill fp16-to-bf16 downcast and the pin-memory warning from
the Metal platform config.

Signed-off-by: Alex Chi <iskyzh@gmail.com>
---
 build_and_run.sh                      |  6 ++++--
 vllm/attention/backends/torch_sdpa.py |  6 +++++-
 vllm/model_executor/custom_op.py      |  3 +++
 vllm/platforms/metal.py               |  9 ---------
 vllm/worker/cpu_worker.py             | 17 ++++++++++++++---
 5 files changed, 26 insertions(+), 15 deletions(-)

diff --git a/build_and_run.sh b/build_and_run.sh
index e28c2a4197d9b..bb84fb611fcb5 100755
--- a/build_and_run.sh
+++ b/build_and_run.sh
@@ -1,8 +1,10 @@
 #!/bin/bash
 
 set -e
+export VLLM_TARGET_DEVICE=metal
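+# Fall back to CPU for any op that MPS does not support.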
 export PYTORCH_ENABLE_MPS_FALLBACK=1
-export VLLM_TARGET_DEVICE=cpu
+# export VLLM_TARGET_DEVICE=cpu
 pip uninstall vllm
 python setup.py install
-vllm serve Qwen/Qwen2.5-0.5B-Instruct --dtype float16
+vllm serve Qwen/Qwen2.5-0.5B-Instruct
diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py
index 14722f5c34049..c0dee318e645c 100644
--- a/vllm/attention/backends/torch_sdpa.py
+++ b/vllm/attention/backends/torch_sdpa.py
@@ -500,6 +500,9 @@ def forward(
                 PagedAttention.write_to_paged_cache(
                     key, value, key_cache, value_cache, updated_slot_mapping,
                     self.kv_cache_dtype, layer._k_scale, layer._v_scale)
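+                # MPS is asynchronous; wait for the cache write to finish.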
+                if torch.backends.mps.is_available():
+                    torch.mps.synchronize()
 
         if attn_type != AttentionType.ENCODER:
             # Decoder self-attention supports chunked prefill.
@@ -576,7 +579,8 @@ def forward(
                 layer._k_scale,
                 layer._v_scale,
             )
-
+        if torch.backends.mps.is_available():
+            torch.mps.synchronize()
         # Reshape the output tensor.
         return output.view(-1, self.num_heads * self.head_size)
 
diff --git a/vllm/model_executor/custom_op.py b/vllm/model_executor/custom_op.py
index 96995c56bf504..0fdd93d8088e8 100644
--- a/vllm/model_executor/custom_op.py
+++ b/vllm/model_executor/custom_op.py
@@ -88,6 +88,9 @@ def dispatch_forward(self):
             return self.forward_xpu
         elif current_platform.is_out_of_tree():
             return self.forward_oot
+        elif current_platform.is_metal():
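+            # Use the native PyTorch implementation on Metal.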
+            return self.forward_native
         else:
             return self.forward_cuda
 
diff --git a/vllm/platforms/metal.py b/vllm/platforms/metal.py
index 47f5ea1c60d8a..6f109b3cafb2e 100644
--- a/vllm/platforms/metal.py
+++ b/vllm/platforms/metal.py
@@ -74,14 +74,6 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
                 "Invalid environment variable VLLM_CPU_KVCACHE_SPACE"
                 f" {kv_cache_space}, expect a positive integer value.")
 
-        scheduler_config = vllm_config.scheduler_config
-        if ((scheduler_config.chunked_prefill_enabled
-             or cache_config.enable_prefix_caching)
-                and model_config.dtype == torch.half):
-            logger.warning("Chunked-prefill on the CPU backend only does not"
-                           " support fp16 for now, cast to bf16.")
-            model_config.dtype = torch.bfloat16 # TODO supported?
-        
         parallel_config = vllm_config.parallel_config
         if parallel_config.worker_cls == "auto":
             parallel_config.worker_cls = "vllm.worker.cpu_worker.CPUWorker"
@@ -90,7 +82,6 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
 
     @classmethod
     def is_pin_memory_available(cls) -> bool:
-        logger.warning("Pin memory is not supported on Metal.")
         return False
 
     @classmethod
diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py
index 868b8462a85b7..950f6cd2b6f2f 100644
--- a/vllm/worker/cpu_worker.py
+++ b/vllm/worker/cpu_worker.py
@@ -40,6 +40,11 @@ def __init__(self, cache_config: CacheConfig, model_config: ModelConfig,
         self.model_config = model_config
         self.parallel_config = parallel_config
 
+        if device_config.device_type == "metal":
+            self.device = torch.device("mps")
+        else:
+            self.device = torch.device("cpu")
+
         self.head_size = model_config.get_head_size()
         self.num_layers = model_config.get_num_layers(parallel_config)
         self.num_heads = model_config.get_num_kv_heads(parallel_config)
@@ -77,7 +82,7 @@ def _allocate_kv_cache(
         kv_cache: List[torch.Tensor] = []
         for _ in range(self.num_layers):
             kv_cache.append(
-                torch.empty(kv_cache_shape, dtype=self.dtype, device="cpu"))
+                torch.empty(kv_cache_shape, dtype=self.dtype, device=self.device))
         return kv_cache
 
     def swap_in(self, src_to_dst: Dict[int, int]) -> None:
@@ -136,6 +141,13 @@ def __init__(
         self.rank = rank
         self.distributed_init_method = distributed_init_method
 
+        if self.device_config.device_type == "cpu":
+            self.device = torch.device("cpu")
+        elif self.device_config.device_type == "metal":
+            self.device = torch.device("mps")
+        else:
+            raise ValueError(f"Invalid device type: {self.device_config.device_type}")
+
         self.is_driver_worker = is_driver_worker
         if self.is_driver_worker:
             assert self.rank == 0, "The driver worker must have rank 0."
@@ -212,7 +224,6 @@ def init_device(self) -> None:
             ret = torch.ops._C_utils.init_cpu_threads_env(self.local_omp_cpuid)
             if ret:
                 logger.info(ret)
-        self.device = torch.device("cpu")
         self.init_distributed_environment()
         # Set random seed.
         set_random_seed(self.model_config.seed)
@@ -350,7 +361,7 @@ def prepare_worker_input(
         virtual_engine: int = execute_model_req.virtual_engine
         num_seq_groups: int = len(execute_model_req.seq_group_metadata_list)
         blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
-                                      device="cpu",
+                                      device=self.device,
                                       dtype=torch.int64).view(-1, 2)
         assert len(execute_model_req.blocks_to_swap_in) == 0
         assert len(execute_model_req.blocks_to_swap_out) == 0