Skip to content

Commit

Permalink
add cuda arm to build matrix
Browse files Browse the repository at this point in the history
  • Loading branch information
tinglvv committed Jul 1, 2024
1 parent 913d92b commit 8e4aedf
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion tools/scripts/generate_binary_build_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
# Accelerator architectures
CPU = "cpu"
CPU_AARCH64 = "cpu-aarch64"
CUDA_AARCH64 = "cuda-aarch64"
CUDA = "cuda"
ROCM = "rocm"

Expand Down Expand Up @@ -103,6 +104,8 @@ def arch_type(arch_version: str) -> str:
return ROCM
elif arch_version == CPU_AARCH64:
return CPU_AARCH64
elif arch_version == CUDA_AARCH64:
return CUDA_AARCH64
else: # arch_version should always be CPU in this case
return CPU

Expand Down Expand Up @@ -154,6 +157,7 @@ def initialize_globals(channel: str, build_python_only: bool) -> None:
},
CPU: "pytorch/manylinux-builder:cpu",
CPU_AARCH64: "pytorch/manylinuxaarch64-builder:cpu-aarch64",
CUDA_AARCH64: "pytorch/manylinuxaarch64-builder:cuda12.4",
}
CONDA_CONTAINER_IMAGES = {
**{
Expand Down Expand Up @@ -188,6 +192,7 @@ def translate_desired_cuda(gpu_arch_type: str, gpu_arch_version: str) -> str:
return {
CPU: "cpu",
CPU_AARCH64: CPU,
CUDA_AARCH64: "cu124",
CUDA: f"cu{gpu_arch_version.replace('.', '')}",
ROCM: f"rocm{gpu_arch_version}",
}.get(gpu_arch_type, gpu_arch_version)
Expand Down Expand Up @@ -490,7 +495,7 @@ def generate_wheels_matrix(
if os == LINUX_AARCH64:
# Only want the one arch as the CPU type is different and
# uses different build/test scripts
arches = [CPU_AARCH64]
arches = [CPU_AARCH64, CUDA_AARCH64]

if with_cuda == ENABLE:
upload_to_base_bucket = "no"
Expand Down

0 comments on commit 8e4aedf

Please sign in to comment.