[Hardware] Initial TPU integration (vllm-project#5292)
WoosukKwon authored Jun 12, 2024
1 parent 847cdcc commit 1a8bfd9
Showing 22 changed files with 1,322 additions and 28 deletions.
19 changes: 19 additions & 0 deletions Dockerfile.tpu
@@ -0,0 +1,19 @@
ARG NIGHTLY_DATE="20240601"
ARG BASE_IMAGE="us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_$NIGHTLY_DATE"

FROM $BASE_IMAGE

WORKDIR /workspace
COPY . /workspace/vllm

ENV VLLM_TARGET_DEVICE="tpu"
# Install aiohttp separately to avoid build errors.
RUN pip install aiohttp
# Install the TPU and Pallas dependencies.
RUN pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html
RUN pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html

# Build vLLM.
RUN cd /workspace/vllm && python setup.py develop

CMD ["/bin/bash"]
2 changes: 1 addition & 1 deletion benchmarks/benchmark_latency.py
@@ -189,7 +189,7 @@ def run_to_completion(profile_dir: Optional[str] = None):
"--device",
type=str,
default="cuda",
choices=["cuda", "cpu"],
choices=["cuda", "cpu", "tpu"],
help='device type for vLLM execution, supporting CUDA and CPU.')
parser.add_argument('--block-size',
type=int,
2 changes: 1 addition & 1 deletion benchmarks/benchmark_throughput.py
@@ -346,7 +346,7 @@ def main(args: argparse.Namespace):
"--device",
type=str,
default="cuda",
choices=["cuda", "cpu"],
choices=["cuda", "cpu", "tpu"],
help='device type for vLLM execution, supporting CUDA and CPU.')
parser.add_argument(
"--enable-prefix-caching",
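With `tpu` added to the `--device` choices, either benchmark script can be pointed at a TPU VM. A minimal sketch of an invocation (the model name is only an example):

$ python benchmarks/benchmark_latency.py --device tpu --model facebook/opt-125m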
75 changes: 75 additions & 0 deletions docs/source/getting_started/tpu-installation.rst
@@ -0,0 +1,75 @@
.. _installation_tpu:

Installation with TPU
=====================

vLLM supports Google Cloud TPUs using PyTorch XLA.

Requirements
------------

* Google Cloud TPU VM (single host)
* TPU versions: v5e, v5p, v4
* Python: 3.10

Installation options:

1. :ref:`Build a docker image with Dockerfile <build_docker_tpu>`.
2. :ref:`Build from source <build_from_source_tpu>`.

.. _build_docker_tpu:

Build a docker image with :code:`Dockerfile.tpu`
------------------------------------------------

`Dockerfile.tpu <https://github.com/vllm-project/vllm/blob/main/Dockerfile.tpu>`_ is provided to build a docker image with TPU support.

.. code-block:: console

    $ docker build -f Dockerfile.tpu -t vllm-tpu .

You can run the docker image with the following command:

.. code-block:: console

    $ # Make sure to add `--privileged --net host --shm-size=16G`.
    $ docker run --privileged --net host --shm-size=16G -it vllm-tpu

.. _build_from_source_tpu:

Build from source
-----------------

You can also build and install the TPU backend from source.

First, install the dependencies:

.. code-block:: console

    $ # (Recommended) Create a new conda environment.
    $ conda create -n myenv python=3.10 -y
    $ conda activate myenv

    $ # Clean up the existing torch and torch-xla packages.
    $ pip uninstall torch torch-xla -y

    $ # Install PyTorch and PyTorch XLA.
    $ export DATE="+20240601"
    $ pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-nightly${DATE}-cp310-cp310-linux_x86_64.whl
    $ pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly${DATE}-cp310-cp310-linux_x86_64.whl

    $ # Install JAX and Pallas.
    $ pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html
    $ pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html

    $ # Install other build dependencies.
    $ pip install packaging aiohttp

Next, build vLLM from source. This will only take a few seconds:

.. code-block:: console

    $ VLLM_TARGET_DEVICE="tpu" python setup.py develop
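Once the build finishes, a quick sanity check that PyTorch XLA can see the TPU is possible with the standard torch_xla API (this snippet is illustrative and not part of the documentation added in this commit):

$ python -c "import torch_xla.core.xla_model as xm; print(xm.xla_device())"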
3 changes: 2 additions & 1 deletion docs/source/index.rst
@@ -63,8 +63,9 @@ Documentation

   getting_started/installation
   getting_started/amd-installation
   getting_started/neuron-installation
   getting_started/cpu-installation
   getting_started/neuron-installation
   getting_started/tpu-installation
   getting_started/quickstart
   getting_started/debugging
   getting_started/examples/examples_index
7 changes: 7 additions & 0 deletions requirements-tpu.txt
@@ -0,0 +1,7 @@
# Common dependencies
-r requirements-common.txt

# Dependencies for TPU
# Currently, the TPU backend uses a nightly version of PyTorch XLA.
# You can install the dependencies in Dockerfile.tpu.
triton # To avoid import errors
22 changes: 17 additions & 5 deletions setup.py
@@ -206,9 +206,9 @@ def build_extensions(self) -> None:


def _is_cuda() -> bool:
return VLLM_TARGET_DEVICE == "cuda" \
and torch.version.cuda is not None \
and not _is_neuron()
has_cuda = torch.version.cuda is not None
return (VLLM_TARGET_DEVICE == "cuda" and has_cuda
and not (_is_neuron() or _is_tpu()))


def _is_hip() -> bool:
@@ -225,10 +225,18 @@ def _is_neuron() -> bool:
    return torch_neuronx_installed or VLLM_TARGET_DEVICE == "neuron"


def _is_tpu() -> bool:
    return VLLM_TARGET_DEVICE == "tpu"


def _is_cpu() -> bool:
    return VLLM_TARGET_DEVICE == "cpu"


def _build_custom_ops() -> bool:
    return _is_cuda() or _is_hip() or _is_cpu()


def _install_punica() -> bool:
    return envs.VLLM_INSTALL_PUNICA_KERNELS

@@ -325,6 +333,8 @@ def get_vllm_version() -> str:
        if neuron_version != MAIN_CUDA_VERSION:
            neuron_version_str = neuron_version.replace(".", "")[:3]
            version += f"+neuron{neuron_version_str}"
    elif _is_tpu():
        version += "+tpu"
    elif _is_cpu():
        version += "+cpu"
    else:
@@ -372,6 +382,8 @@ def _read_requirements(filename: str) -> List[str]:
requirements = _read_requirements("requirements-rocm.txt")
elif _is_neuron():
requirements = _read_requirements("requirements-neuron.txt")
elif _is_tpu():
requirements = _read_requirements("requirements-tpu.txt")
elif _is_cpu():
requirements = _read_requirements("requirements-cpu.txt")
else:
@@ -385,7 +397,7 @@ def _read_requirements(filename: str) -> List[str]:
if _is_cuda() or _is_hip():
ext_modules.append(CMakeExtension(name="vllm._moe_C"))

if not _is_neuron():
if _build_custom_ops():
ext_modules.append(CMakeExtension(name="vllm._C"))

if _install_punica():
@@ -428,6 +440,6 @@ def _read_requirements(filename: str) -> List[str]:
    extras_require={
        "tensorizer": ["tensorizer>=2.9.0"],
    },
    cmdclass={"build_ext": cmake_build_ext} if not _is_neuron() else {},
    cmdclass={"build_ext": cmake_build_ext} if _build_custom_ops() else {},
    package_data=package_data,
)
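A simple way to confirm which backend a source build selected is the local version suffix appended by `get_vllm_version` above; a TPU build should report a version ending in `+tpu`:

$ python -c "import vllm; print(vllm.__version__)"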