-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathsetup_ops.py
41 lines (34 loc) · 1.24 KB
/
setup_ops.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import os
import paddle
from paddle.utils.cpp_extension import CppExtension
from paddle.utils.cpp_extension import CUDAExtension
from paddle.utils.cpp_extension import setup
def get_sources():
    """Collect the native source files to compile for the extension.

    Scans the ``csrc`` directory located next to this script and keeps the
    ``.cc`` files, plus the ``.cu`` files when Paddle reports that it was
    compiled with CUDA.

    Returns:
        tuple: ``(csrc_dir_path, cpp_files)`` where ``csrc_dir_path`` is the
        path of the ``csrc`` directory and ``cpp_files`` is the list of
        selected source file paths, sorted by file name.
    """
    csrc_dir_path = os.path.join(os.path.dirname(__file__), "csrc")

    # The CUDA check does not depend on the file being inspected, so it is
    # evaluated once instead of once per directory entry.
    if paddle.device.is_compiled_with_cuda():
        wanted_suffixes = (".cc", ".cu")
    else:
        wanted_suffixes = (".cc",)

    # os.listdir() returns entries in arbitrary order; sorting keeps the
    # source list stable from one build to the next.
    cpp_files = [
        os.path.join(csrc_dir_path, entry)
        for entry in sorted(os.listdir(csrc_dir_path))
        if entry.endswith(wanted_suffixes)
    ]
    return csrc_dir_path, cpp_files
def get_extensions():
    """Build the list of extension modules handed to ``setup``.

    Selects ``CUDAExtension`` when Paddle reports that it was compiled with
    CUDA and ``CppExtension`` otherwise, then configures it with the sources
    and directory returned by ``get_sources``.

    Returns:
        list: A single-element list containing the configured extension.
    """
    csrc_dir_path, cpp_files = get_sources()

    if paddle.device.is_compiled_with_cuda():
        extension_cls = CUDAExtension
    else:
        extension_cls = CppExtension

    return [
        extension_cls(
            sources=cpp_files,
            # include_dirs is expected to be a list of directories
            # (setuptools Extension convention). A bare string risks being
            # split into one-character "paths" if it is converted with
            # list() downstream, so wrap the directory in a list.
            include_dirs=[csrc_dir_path],
        )
    ]
# Register the package and build its native extension module(s).
setup(
    # Package identity.
    name="paddle_scatter_ops",
    version="1.0",
    description="Paddle extension of scatter and segment operators with min and max reduction methods, originally from https://github.com/rusty1s/pytorch_scatter",
    # Authorship and project home.
    author="NKNaN",
    url="https://github.com/PFCCLab/paddle_scatter",
    # Native sources, chosen according to how Paddle was built.
    ext_modules=get_extensions(),
)