diff --git a/build_and_run.sh b/build_and_run.sh
index e28c2a4197d9b..bb84fb611fcb5 100755
--- a/build_and_run.sh
+++ b/build_and_run.sh
@@ -1,8 +1,9 @@
 #!/bin/bash
 set -e
+export VLLM_TARGET_DEVICE=metal
 export PYTORCH_ENABLE_MPS_FALLBACK=1
-export VLLM_TARGET_DEVICE=cpu
+# export VLLM_TARGET_DEVICE=cpu
 
 pip uninstall vllm
 python setup.py install
-vllm serve Qwen/Qwen2.5-0.5B-Instruct --dtype float16
+vllm serve Qwen/Qwen2.5-0.5B-Instruct
diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py
index 14722f5c34049..c0dee318e645c 100644
--- a/vllm/attention/backends/torch_sdpa.py
+++ b/vllm/attention/backends/torch_sdpa.py
@@ -500,6 +500,7 @@ def forward(
             PagedAttention.write_to_paged_cache(
                 key, value, key_cache, value_cache, updated_slot_mapping,
                 self.kv_cache_dtype, layer._k_scale, layer._v_scale)
+            torch.mps.synchronize()
 
         if attn_type != AttentionType.ENCODER:
             # Decoder self-attention supports chunked prefill.
@@ -576,7 +577,7 @@ def forward(
                 layer._k_scale,
                 layer._v_scale,
             )
-
+            torch.mps.synchronize()
         # Reshape the output tensor.
         return output.view(-1, self.num_heads * self.head_size)
 
diff --git a/vllm/model_executor/custom_op.py b/vllm/model_executor/custom_op.py
index 96995c56bf504..0fdd93d8088e8 100644
--- a/vllm/model_executor/custom_op.py
+++ b/vllm/model_executor/custom_op.py
@@ -88,6 +88,8 @@ def dispatch_forward(self):
             return self.forward_xpu
         elif current_platform.is_out_of_tree():
             return self.forward_oot
+        elif current_platform.is_metal():
+            return self.forward_native
         else:
             return self.forward_cuda
 
diff --git a/vllm/platforms/metal.py b/vllm/platforms/metal.py
index 47f5ea1c60d8a..6f109b3cafb2e 100644
--- a/vllm/platforms/metal.py
+++ b/vllm/platforms/metal.py
@@ -74,14 +74,6 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
                 "Invalid environment variable VLLM_CPU_KVCACHE_SPACE"
                 f" {kv_cache_space}, expect a positive integer value.")
 
-        scheduler_config = vllm_config.scheduler_config
-        if ((scheduler_config.chunked_prefill_enabled
-             or cache_config.enable_prefix_caching)
-                and model_config.dtype == torch.half):
-            logger.warning("Chunked-prefill on the CPU backend only does not"
-                           " support fp16 for now, cast to bf16.")
-            model_config.dtype = torch.bfloat16  # TODO supported?
-
         parallel_config = vllm_config.parallel_config
         if parallel_config.worker_cls == "auto":
             parallel_config.worker_cls = "vllm.worker.cpu_worker.CPUWorker"
@@ -90,7 +82,6 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
 
     @classmethod
     def is_pin_memory_available(cls) -> bool:
-        logger.warning("Pin memory is not supported on Metal.")
         return False
 
     @classmethod
diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py
index 868b8462a85b7..950f6cd2b6f2f 100644
--- a/vllm/worker/cpu_worker.py
+++ b/vllm/worker/cpu_worker.py
@@ -40,6 +40,11 @@ def __init__(self, cache_config: CacheConfig, model_config: ModelConfig,
         self.model_config = model_config
         self.parallel_config = parallel_config
 
+        if device_config.device_type == "metal":
+            self.device = torch.device("mps")
+        else:
+            self.device = torch.device("cpu")
+
         self.head_size = model_config.get_head_size()
         self.num_layers = model_config.get_num_layers(parallel_config)
         self.num_heads = model_config.get_num_kv_heads(parallel_config)
@@ -77,7 +82,7 @@ def _allocate_kv_cache(
         kv_cache: List[torch.Tensor] = []
         for _ in range(self.num_layers):
             kv_cache.append(
-                torch.empty(kv_cache_shape, dtype=self.dtype, device="cpu"))
+                torch.empty(kv_cache_shape, dtype=self.dtype, device=self.device))
         return kv_cache
 
     def swap_in(self, src_to_dst: Dict[int, int]) -> None:
@@ -136,6 +141,13 @@ def __init__(
         self.rank = rank
         self.distributed_init_method = distributed_init_method
 
+        if self.device_config.device_type == "cpu":
+            self.device = torch.device("cpu")
+        elif self.device_config.device_type == "metal":
+            self.device = torch.device("mps")
+        else:
+            raise ValueError(f"Invalid device type: {self.device_config.device_type}")
+
         self.is_driver_worker = is_driver_worker
         if self.is_driver_worker:
             assert self.rank == 0, "The driver worker must have rank 0."
@@ -212,7 +224,6 @@ def init_device(self) -> None:
         ret = torch.ops._C_utils.init_cpu_threads_env(self.local_omp_cpuid)
         if ret:
             logger.info(ret)
-        self.device = torch.device("cpu")
         self.init_distributed_environment()
         # Set random seed.
         set_random_seed(self.model_config.seed)
@@ -350,7 +361,7 @@ def prepare_worker_input(
         virtual_engine: int = execute_model_req.virtual_engine
         num_seq_groups: int = len(execute_model_req.seq_group_metadata_list)
         blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
-                                      device="cpu",
+                                      device=self.device,
                                       dtype=torch.int64).view(-1, 2)
         assert len(execute_model_req.blocks_to_swap_in) == 0
         assert len(execute_model_req.blocks_to_swap_out) == 0
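
Note on the device-selection changes above (illustrative sketch, not part of the patch): both worker constructors in cpu_worker.py map device_config.device_type == "metal" to torch.device("mps") and fall back to CPU otherwise. The standalone snippet below mirrors that logic; the helper name resolve_device and the explicit torch.backends.mps.is_available() guard are additions for this example only, since the patch itself assumes an MPS-enabled PyTorch build.

    import torch

    def resolve_device(device_type: str) -> torch.device:
        # Hypothetical helper mirroring the patch's device selection.
        if device_type == "metal":
            # Extra guard (not in the patch): require an MPS-enabled PyTorch build.
            if not torch.backends.mps.is_available():
                raise RuntimeError(
                    "device_type 'metal' needs an MPS-enabled PyTorch build")
            return torch.device("mps")
        if device_type == "cpu":
            return torch.device("cpu")
        raise ValueError(f"Invalid device type: {device_type}")

    if __name__ == "__main__":
        print(resolve_device("cpu"))  # -> device(type='cpu')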