Skip to content

Commit

Permalink
fix setup.py
Browse files Browse the repository at this point in the history
  • Loading branch information
season0528 committed Jul 19, 2024
1 parent 6e21003 commit 77c3c84
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,18 @@
import os
import os.path as osp
import platform
import subprocess
import sys
from itertools import product

import torch
from setuptools import find_packages, setup
from torch.__config__ import parallel_info
from torch.utils.cpp_extension import (CUDA_HOME, BuildExtension, CppExtension,
CUDAExtension)

__version__ = '2.1.2'
URL = 'https://github.com/rusty1s/pytorch_scatter'

CUDA_HOME = os.environ.get("CUDA_HOME", None)
WITH_CUDA = False
if torch.cuda.is_available():
WITH_CUDA = CUDA_HOME is not None or torch.version.hip
Expand All @@ -28,6 +28,10 @@
BUILD_DOCS = os.getenv('BUILD_DOCS', '0') == '1'
WITH_SYMBOLS = os.getenv('WITH_SYMBOLS', '0') == '1'

def get_cuda_bare_metal_version(cuda_home):
output = subprocess.check_output([cuda_home + "/bin/nvcc", "-V"], universal_newlines=True).split()
release_idx = output.index("release")
return output[release_idx+1].split(",")[0].replace('.', '')

def get_extensions():
extensions = []
Expand Down Expand Up @@ -108,7 +112,7 @@ def get_extensions():


install_requires = ["torch>=1.8.0"]
extra_index_url = ["https://download.pytorch.org/whl/"]
extra_index_url = ["https://download.pytorch.org/whl/cpu"] if suffices == ["cpu"] else [f"https://download.pytorch.org/whl/cu{get_cuda_bare_metal_version(CUDA_HOME)}"]

test_requires = [
'pytest',
Expand Down

0 comments on commit 77c3c84

Please sign in to comment.