From 77c3c849212ec92dda86830286eb1ec74e3f6781 Mon Sep 17 00:00:00 2001 From: zigzagcai Date: Fri, 19 Jul 2024 14:51:47 +0800 Subject: [PATCH] fix setup.py --- setup.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/setup.py b/setup.py index 97a8612c..a39f6d40 100644 --- a/setup.py +++ b/setup.py @@ -2,18 +2,18 @@ import os import os.path as osp import platform +import subprocess import sys from itertools import product import torch from setuptools import find_packages, setup from torch.__config__ import parallel_info -from torch.utils.cpp_extension import (CUDA_HOME, BuildExtension, CppExtension, - CUDAExtension) __version__ = '2.1.2' URL = 'https://github.com/rusty1s/pytorch_scatter' +CUDA_HOME = os.environ.get("CUDA_HOME", None) WITH_CUDA = False if torch.cuda.is_available(): WITH_CUDA = CUDA_HOME is not None or torch.version.hip @@ -28,6 +28,10 @@ BUILD_DOCS = os.getenv('BUILD_DOCS', '0') == '1' WITH_SYMBOLS = os.getenv('WITH_SYMBOLS', '0') == '1' +def get_cuda_bare_metal_version(cuda_home): + output = subprocess.check_output([cuda_home + "/bin/nvcc", "-V"], universal_newlines=True).split() + release_idx = output.index("release") + return output[release_idx+1].split(",")[0].replace('.', '') def get_extensions(): extensions = [] @@ -108,7 +112,7 @@ def get_extensions(): install_requires = ["torch>=1.8.0"] -extra_index_url = ["https://download.pytorch.org/whl/"] +extra_index_url = ["https://download.pytorch.org/whl/cpu"] if suffices == ["cpu"] else [f"https://download.pytorch.org/whl/cu{get_cuda_bare_metal_version(CUDA_HOME)}"] test_requires = [ 'pytest',