Skip to content

Commit

Permalink
Don't build punica kernels by default (vllm-project#2605)
Browse files · Browse the repository at this point in the history
  • Loading branch information
pcmoritz authored Jan 26, 2024
1 parent 3a0e1fc commit 390b495
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 4 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/scripts/build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ $python_executable -m pip install -r requirements.txt

# Limit the number of parallel jobs to avoid OOM
export MAX_JOBS=1
# Make sure punica is built for the release (for LoRA)
export VLLM_INSTALL_PUNICA_KERNELS=1

# Build
$python_executable setup.py bdist_wheel --dist-dir=dist
2 changes: 2 additions & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ ENV MAX_JOBS=${max_jobs}
# number of threads used by nvcc
ARG nvcc_threads=8
ENV NVCC_THREADS=$nvcc_threads
# make sure punica kernels are built (for LoRA)
ENV VLLM_INSTALL_PUNICA_KERNELS=1

RUN python3 setup.py build_ext --inplace
#################### EXTENSION Build IMAGE ####################
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ def get_torch_arch_list() -> Set[str]:
with contextlib.suppress(ValueError):
torch_cpp_ext.COMMON_NVCC_FLAGS.remove(flag)

install_punica = bool(int(os.getenv("VLLM_INSTALL_PUNICA_KERNELS", "1")))
install_punica = bool(int(os.getenv("VLLM_INSTALL_PUNICA_KERNELS", "0")))
device_count = torch.cuda.device_count()
for i in range(device_count):
major, minor = torch.cuda.get_device_capability(i)
Expand Down
9 changes: 6 additions & 3 deletions vllm/lora/punica.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,10 +157,13 @@ def _raise_exc(
**kwargs # pylint: disable=unused-argument
):
if torch.cuda.get_device_capability() < (8, 0):
raise ImportError(
"LoRA kernels require compute capability>=8.0") from import_exc
raise ImportError("punica LoRA kernels require compute "
"capability>=8.0") from import_exc
else:
raise import_exc
raise ImportError(
"punica LoRA kernels could not be imported. If you built vLLM "
"from source, make sure VLLM_INSTALL_PUNICA_KERNELS=1 env var "
"was set.") from import_exc

bgmv = _raise_exc
add_lora = _raise_exc
Expand Down

0 comments on commit 390b495

Please sign in to comment.