initial support for metal
Signed-off-by: Alex Chi <[email protected]>
skyzh committed Feb 1, 2025
1 parent a1fc18c commit cd180dd
Showing 14 changed files with 176 additions and 32 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
@@ -85,7 +85,7 @@ find_package(Torch REQUIRED)
#
if (NOT VLLM_TARGET_DEVICE STREQUAL "cuda" AND
    NOT VLLM_TARGET_DEVICE STREQUAL "rocm")
-    if (VLLM_TARGET_DEVICE STREQUAL "cpu")
+    if (VLLM_TARGET_DEVICE STREQUAL "cpu" OR VLLM_TARGET_DEVICE STREQUAL "metal")
        include(${CMAKE_CURRENT_LIST_DIR}/cmake/cpu_extension.cmake)
    else()
        return()
8 changes: 8 additions & 0 deletions build_and_run.sh
@@ -0,0 +1,8 @@
#!/bin/bash

set -e
export PYTORCH_ENABLE_MPS_FALLBACK=1
export VLLM_TARGET_DEVICE=cpu
pip uninstall -y vllm  # -y avoids the interactive prompt, which would stall under set -e
python setup.py install
vllm serve Qwen/Qwen2.5-0.5B-Instruct --dtype float16
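
Once the script brings the server up, a quick smoke test against the OpenAI-compatible endpoint that `vllm serve` exposes can confirm it is answering. A sketch, assuming the default localhost:8000 and the model name served above:

import json
import urllib.request

# Minimal /v1/completions request using only the standard library.
payload = {
    "model": "Qwen/Qwen2.5-0.5B-Instruct",
    "prompt": "The capital of France is",
    "max_tokens": 16,
}
req = urllib.request.Request(
    "http://localhost:8000/v1/completions",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    print(json.loads(resp.read())["choices"][0]["text"])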
6 changes: 6 additions & 0 deletions collect_env.py
@@ -515,6 +515,12 @@ def is_xnnpack_available():
    else:
        return "N/A"


+def is_mps_available():
+    if TORCH_AVAILABLE:
+        return str(torch.backends.mps.is_available())
+    else:
+        return "N/A"


def get_env_vars():
    env_vars = ''
    secret_terms=('secret', 'token', 'api', 'access', 'password')
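
What the new `is_mps_available` helper wraps can be checked interactively; `torch.backends.mps.is_built()` additionally distinguishes a wheel compiled with MPS support from a machine that can actually use it:

import torch

# is_built(): this PyTorch wheel was compiled with MPS support.
# is_available(): a Metal-capable device is also present (macOS 12.3+).
print(torch.backends.mps.is_built())
print(torch.backends.mps.is_available())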
14 changes: 14 additions & 0 deletions requirements-metal.txt
@@ -0,0 +1,14 @@
# Common dependencies
-r requirements-common.txt

# Dependencies for CPU and Apple Silicon (Darwin) builds
torch==2.5.1; platform_machine == "ppc64le" or platform_machine == "aarch64" or platform_system == "Darwin"

# required for the image processor of minicpm-o-2_6, this must be updated alongside torch
torchaudio; platform_machine != "ppc64le"
torchaudio==2.5.1; platform_machine == "ppc64le"

# required for the image processor of phi3v, this must be updated alongside torch
torchvision; platform_machine != "ppc64le"
torchvision==0.20.1; platform_machine == "ppc64le"
datasets # for benchmark scripts
15 changes: 8 additions & 7 deletions setup.py
@@ -34,11 +34,7 @@ def load_module_from_path(module_name, path):

VLLM_TARGET_DEVICE = envs.VLLM_TARGET_DEVICE

-if sys.platform.startswith("darwin") and VLLM_TARGET_DEVICE != "cpu":
-    logger.warning(
-        "VLLM_TARGET_DEVICE automatically set to `cpu` due to macOS")
-    VLLM_TARGET_DEVICE = "cpu"
-elif not (sys.platform.startswith("linux")
+if not (sys.platform.startswith("linux")
          or sys.platform.startswith("darwin")):
    logger.warning(
        "vLLM only supports Linux platform (including WSL) and MacOS."
@@ -390,10 +386,11 @@ def _is_openvino() -> bool:
def _is_xpu() -> bool:
    return VLLM_TARGET_DEVICE == "xpu"


+def _is_metal() -> bool:
+    return VLLM_TARGET_DEVICE == "metal"
+

def _build_custom_ops() -> bool:
-    return _is_cuda() or _is_hip() or _is_cpu()
+    return _is_cuda() or _is_hip() or _is_cpu() or _is_metal()


def get_rocm_version():
    # Get the Rocm version from the ROCM_HOME/bin/librocm-core.so
@@ -521,6 +518,8 @@ def get_vllm_version() -> str:
        version += f"{sep}cpu"
    elif _is_xpu():
        version += f"{sep}xpu"
+    elif _is_metal():
+        version += f"{sep}metal"
    else:
        raise RuntimeError("Unknown runtime environment")

@@ -581,6 +580,8 @@ def _read_requirements(filename: str) -> List[str]:
        requirements = _read_requirements("requirements-cpu.txt")
    elif _is_xpu():
        requirements = _read_requirements("requirements-xpu.txt")
+    elif _is_metal():
+        requirements = _read_requirements("requirements-metal.txt")
    else:
        raise ValueError(
            "Unsupported platform, please use CUDA, ROCm, Neuron, HPU, "
19 changes: 10 additions & 9 deletions vllm/attention/backends/torch_sdpa.py
@@ -283,6 +283,7 @@ class TorchSDPAMetadataBuilder(AttentionMetadataBuilder[TorchSDPAMetadata]):
    def __init__(self, input_builder: ModelInputForCPUBuilder) -> None:
        self.chunked_prefill = input_builder.chunked_prefill
        self.input_builder = input_builder
+        self._device = input_builder.device

    def prepare(self):
        self.input_data = self.input_builder.input_data
@@ -294,28 +295,28 @@ def build(self, seq_lens: List[int], query_lens: List[int],
        prefill_query_lens = query_lens[0:input_data.num_prefills]
        slot_mapping = torch.tensor(input_data.slot_mapping,
                                    dtype=torch.long,
-                                   device="cpu")
+                                   device=self._device)

        # For chunked-prefill
        if self.chunked_prefill and input_data.num_prefill_tokens != 0:
            prefill_block_tables = make_tensor_with_pad(
                self.input_data.prefill_block_tables,
                pad=0,
                dtype=torch.int32,
-               device="cpu",
+               device=self._device,
            )
            query_lens_tensor = torch.tensor(prefill_query_lens,
                                             dtype=torch.int32,
-                                            device="cpu")
+                                            device=self._device)
            kv_lens_tensor = torch.tensor(prefill_seq_lens,
                                          dtype=torch.int32,
-                                         device="cpu")
+                                         device=self._device)
            query_start_loc = torch.zeros(input_data.num_prefills + 1,
                                          dtype=torch.int32,
-                                         device="cpu")
+                                         device=self._device)
            kv_start_loc = torch.zeros(input_data.num_prefills + 1,
                                       dtype=torch.int32,
-                                      device="cpu")
+                                      device=self._device)
            torch.cumsum(query_lens_tensor,
                         dim=0,
                         dtype=torch.int32,
@@ -338,20 +339,20 @@ def build(self, seq_lens: List[int], query_lens: List[int],
            seq_lens_tensor = torch.tensor(
                input_data.seq_lens[input_data.num_prefills:],
                dtype=torch.int32,
-               device="cpu",
+               device=self._device,
            )
            block_tables = make_tensor_with_pad(
                self.input_data.decode_block_tables,
                pad=0,
                dtype=torch.int32,
-               device="cpu",
+               device=self._device,
            )
        else:
            block_tables = torch.tensor([])
            seq_lens_tensor = torch.tensor(
                input_data.seq_lens[:input_data.num_prefills],
                dtype=torch.int32,
-               device="cpu",
+               device=self._device,
            )

        # For multi-modal models
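
For context, make_tensor_with_pad (from vllm.utils) is what materializes the ragged block tables on the chosen device. A simplified sketch of its behavior here, with pad_rows as a hypothetical stand-in rather than the actual vLLM helper:

from typing import List

import torch

# Right-pad ragged rows to a rectangle, then materialize the result on the
# target device ("cpu" before this change, the builder's device after it).
def pad_rows(rows: List[List[int]], pad: int, dtype: torch.dtype,
             device: str) -> torch.Tensor:
    width = max((len(r) for r in rows), default=0)
    padded = [r + [pad] * (width - len(r)) for r in rows]
    return torch.tensor(padded, dtype=dtype, device=device)

print(pad_rows([[1, 2, 3], [4]], pad=0, dtype=torch.int32, device="cpu"))
# tensor([[1, 2, 3],
#         [4, 0, 0]], dtype=torch.int32)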
2 changes: 2 additions & 0 deletions vllm/config.py
@@ -1569,6 +1569,8 @@ def __init__(self, device: str = "auto") -> None:
        # Some device types require processing inputs on CPU
        if self.device_type in ["neuron", "openvino"]:
            self.device = torch.device("cpu")
+        elif self.device_type in ["metal"]:
+            self.device = torch.device("mps")
        elif self.device_type in ["tpu"]:
            self.device = None
        else:
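
A quick sanity check of the new "metal" to torch.device("mps") mapping (assumes a PyTorch build with MPS support):

import torch

if torch.backends.mps.is_available():
    x = torch.ones(2, 2, device="mps", dtype=torch.float16)
    print(x.device)              # mps:0
    print((x + x).sum().item())  # 8.0, computed on the Metal device
else:
    print("MPS not available on this machine")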
1 change: 1 addition & 0 deletions vllm/engine/arg_utils.py
@@ -38,6 +38,7 @@
"tpu",
"xpu",
"hpu",
"metal",
]


14 changes: 7 additions & 7 deletions vllm/model_executor/layers/activation.py
@@ -32,7 +32,7 @@ def __init__(self, threshold: float = 0.):
        self.threshold = threshold
        if current_platform.is_cuda_alike():
            self.op = torch.ops._C.fatrelu_and_mul
-        elif current_platform.is_cpu():
+        elif current_platform.is_cpu() or current_platform.is_metal():
            self._forward_method = self.forward_native

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
@@ -63,7 +63,7 @@ class SiluAndMul(CustomOp):

    def __init__(self):
        super().__init__()
-        if current_platform.is_cuda_alike() or current_platform.is_cpu():
+        if (current_platform.is_cuda_alike() or current_platform.is_cpu()
+                or current_platform.is_metal()):
            self.op = torch.ops._C.silu_and_mul
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops
@@ -107,7 +107,7 @@ def __init__(self):
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops
            self.op = ipex_ops.silu_and_mul
-        elif current_platform.is_cpu():
+        elif current_platform.is_cpu() or current_platform.is_metal():
            self._forward_method = self.forward_native

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
@@ -142,7 +142,7 @@ def __init__(self, approximate: str = "none"):
        self.approximate = approximate
        if approximate not in ("none", "tanh"):
            raise ValueError(f"Unknown approximate mode: {approximate}")
-        if current_platform.is_cuda_alike() or current_platform.is_cpu():
+        if (current_platform.is_cuda_alike() or current_platform.is_cpu()
+                or current_platform.is_metal()):
            if approximate == "none":
                self.op = torch.ops._C.gelu_and_mul
            elif approximate == "tanh":
@@ -182,7 +182,7 @@ class NewGELU(CustomOp):

    def __init__(self):
        super().__init__()
-        if current_platform.is_cuda_alike() or current_platform.is_cpu():
+        if (current_platform.is_cuda_alike() or current_platform.is_cpu()
+                or current_platform.is_metal()):
            self.op = torch.ops._C.gelu_new
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops
@@ -208,7 +208,7 @@ class FastGELU(CustomOp):

    def __init__(self):
        super().__init__()
-        if current_platform.is_cuda_alike() or current_platform.is_cpu():
+        if (current_platform.is_cuda_alike() or current_platform.is_cpu()
+                or current_platform.is_metal()):
            self.op = torch.ops._C.gelu_fast
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops
@@ -233,7 +233,7 @@ class QuickGELU(CustomOp):
    # https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py#L90
    def __init__(self):
        super().__init__()
-        if current_platform.is_cuda_alike() or current_platform.is_cpu():
+        if (current_platform.is_cuda_alike() or current_platform.is_cpu()
+                or current_platform.is_metal()):
            self.op = torch.ops._C.gelu_quick
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops
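
The forward_native fallback these branches select is plain PyTorch, so it runs on MPS without custom kernels. For SiluAndMul it amounts to the following (a sketch of the reference semantics, not the vLLM source verbatim):

import torch
import torch.nn.functional as F

# Split the last dimension in half, gate the first half with SiLU,
# and multiply elementwise by the second half.
def silu_and_mul(x: torch.Tensor) -> torch.Tensor:
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]

print(silu_and_mul(torch.randn(4, 16)).shape)  # torch.Size([4, 8])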
16 changes: 12 additions & 4 deletions vllm/platforms/__init__.py
@@ -96,15 +96,22 @@ def xpu_platform_plugin() -> Optional[str]:
return "vllm.platforms.xpu.XPUPlatform" if is_xpu else None


def metal_platform_plugin() -> Optional[str]:
is_metal = False
try:
from importlib.metadata import version
import torch
is_metal = torch.backends.mps.is_available() and "metal" in version("vllm")
except Exception:
pass
return "vllm.platforms.metal.MetalPlatform" if is_metal else None


def cpu_platform_plugin() -> Optional[str]:
is_cpu = False
try:
from importlib.metadata import version
is_cpu = "cpu" in version("vllm")
if not is_cpu:
import platform
is_cpu = platform.machine().lower().startswith("arm")

except Exception:
pass

@@ -142,6 +149,7 @@ def openvino_platform_plugin() -> Optional[str]:
    'cpu': cpu_platform_plugin,
    'neuron': neuron_platform_plugin,
    'openvino': openvino_platform_plugin,
+    'metal': metal_platform_plugin,
}


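
Both conditions the new plugin checks can be inspected by hand; the "metal" substring comes from the version suffix that setup.py's get_vllm_version() appends. A sketch, assuming a wheel built with VLLM_TARGET_DEVICE=metal (the exact base version will differ):

from importlib.metadata import version

import torch

print(version("vllm"))                    # e.g. "0.7.0+metal" on a metal build
print(torch.backends.mps.is_available())  # True on Apple Silicon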
5 changes: 5 additions & 0 deletions vllm/platforms/interface.py
@@ -28,6 +28,7 @@ class _Backend(enum.Enum):
    XFORMERS = enum.auto()
    ROCM_FLASH = enum.auto()
    TORCH_SDPA = enum.auto()
+    TORCH_SDPA_2 = enum.auto()
    OPENVINO = enum.auto()
    FLASHINFER = enum.auto()
    HPU_ATTN = enum.auto()
@@ -44,6 +45,7 @@ class PlatformEnum(enum.Enum):
    HPU = enum.auto()
    XPU = enum.auto()
    CPU = enum.auto()
+    METAL = enum.auto()
    NEURON = enum.auto()
    OPENVINO = enum.auto()
    OOT = enum.auto()
@@ -123,6 +125,9 @@ def is_xpu(self) -> bool:
    def is_cpu(self) -> bool:
        return self._enum == PlatformEnum.CPU

+    def is_metal(self) -> bool:
+        return self._enum == PlatformEnum.METAL
+
    def is_neuron(self) -> bool:
        return self._enum == PlatformEnum.NEURON

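
With the metal wheel installed, the resolved platform can be verified through the public accessor (a sketch; the output assumes the plugin above was selected):

from vllm.platforms import current_platform

print(current_platform.is_metal())   # True
print(current_platform.device_type)  # "metal"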
98 changes: 98 additions & 0 deletions vllm/platforms/metal.py
@@ -0,0 +1,98 @@
from typing import TYPE_CHECKING, Optional

import psutil
import torch

from vllm.logger import init_logger

from .interface import Platform, PlatformEnum, _Backend

logger = init_logger(__name__)

if TYPE_CHECKING:
    from vllm.config import VllmConfig
else:
    VllmConfig = None


class MetalPlatform(Platform):
    _enum = PlatformEnum.METAL
    device_name: str = "metal"
    device_type: str = "metal"
    dispatch_key: str = "CPU"

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        return "metal"

    @classmethod
    def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
                             dtype: torch.dtype, kv_cache_dtype: Optional[str],
                             block_size: int, use_v1: bool) -> str:
        if selected_backend != _Backend.TORCH_SDPA:
            logger.info("Cannot use %s backend on Metal.", selected_backend)
        logger.info("Using Torch SDPA backend.")
        return "vllm.attention.backends.torch_sdpa.TorchSDPABackend"

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        return psutil.virtual_memory().total

    @classmethod
    def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
        return False

    @classmethod
    def inference_mode(cls):
        return torch.no_grad()

    @classmethod
    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        import vllm.envs as envs
        from vllm.utils import GiB_bytes
        model_config = vllm_config.model_config
        cache_config = vllm_config.cache_config

        if cache_config and cache_config.block_size is None:
            cache_config.block_size = 16

        kv_cache_space = envs.VLLM_CPU_KVCACHE_SPACE

        if kv_cache_space >= 0:
            if kv_cache_space == 0:
                cache_config.cpu_kvcache_space_bytes = 4 * GiB_bytes  # type: ignore
                logger.warning(
                    "Environment variable VLLM_CPU_KVCACHE_SPACE (GB) "
                    "for CPU backend is not set, using 4 by default.")
            else:
                cache_config.cpu_kvcache_space_bytes = kv_cache_space * GiB_bytes  # type: ignore # noqa
        else:
            raise RuntimeError(
                "Invalid environment variable VLLM_CPU_KVCACHE_SPACE"
                f" {kv_cache_space}, expected a non-negative integer.")

        scheduler_config = vllm_config.scheduler_config
        if ((scheduler_config.chunked_prefill_enabled
             or cache_config.enable_prefix_caching)
                and model_config.dtype == torch.half):
            logger.warning("Chunked-prefill on the CPU backend does not"
                           " support fp16 for now, casting to bf16.")
            model_config.dtype = torch.bfloat16  # TODO: confirm bf16 is supported here

        parallel_config = vllm_config.parallel_config
        if parallel_config.worker_cls == "auto":
            parallel_config.worker_cls = "vllm.worker.cpu_worker.CPUWorker"

        assert vllm_config.device_config.device_type == "metal"

    @classmethod
    def is_pin_memory_available(cls) -> bool:
        logger.warning("Pin memory is not supported on Metal.")
        return False

    @classmethod
    def get_punica_wrapper(cls) -> str:
        return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"