Skip to content

Commit

Permalink
update extension module
Browse files Browse the repository at this point in the history
  • Loading branch information
gyzhou2000 committed Jul 16, 2024
1 parent 3ce4e17 commit d26014a
Showing 1 changed file with 13 additions and 10 deletions.
23 changes: 13 additions & 10 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,16 @@
import os
import os.path as osp
from setuptools import setup, find_packages
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CppExtension
from ggl_build_extension import PyCudaExtension, PyCPUExtension
# from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CppExtension
# from ggl_build_extension import PyCudaExtension, PyCPUExtension
from tensorlayerx.utils import PyCppExtension, PyCUDAExtension, PyBuildExtension

# TODO will depend on different host
WITH_CUDA = False
# WITH_CUDA = True

cuda_macro = ('COMPILE_WITH_CUDA', True)
omp_macro = ('COMPLIE_WITH_OMP', True) # Note: OpenMP needs gcc>4.2.0
omp_macro = ('COMPILE_WITH_OMP', True) # Note: OpenMP needs gcc>4.2.0
compile_args = {
'cxx': ['-fopenmp', '-std=c++17']
}
Expand Down Expand Up @@ -45,20 +46,22 @@ def load_mpops_extensions():
file_list.extend([osp.join(src_dir, f) for f in src_files])

if not WITH_CUDA:
extensions.append(CppExtension(
extensions.append(PyCppExtension(
name=osp.join(mpops_dir, f'_{mpops_prefix}').replace(osp.sep, "."),
sources=[f for f in file_list],
extra_compile_args=compile_args
extra_compile_args=compile_args,
use_torch=True
))
else:
extensions.append(CUDAExtension(
extensions.append(PyCUDAExtension(
name=osp.join(mpops_dir, f'_{mpops_prefix}').replace(osp.sep, "."),
sources=[f for f in file_list],
define_macros=[
cuda_macro,
omp_macro
],
extra_compile_args=compile_args
extra_compile_args=compile_args,
use_torch=True
))

return extensions
Expand Down Expand Up @@ -91,14 +94,14 @@ def load_ops_extensions():
if not src_files:
continue
if not is_cuda_ext:
extensions.append(PyCPUExtension(
extensions.append(PyCppExtension(
name=osp.join(ops_dir, f'_{ops_prefix}').replace(osp.sep, "."),
sources=[osp.join(src_dir, f) for f in src_files],
include_dirs=[osp.abspath(osp.join('third_party', d)) for d in ops_third_party_deps[i]],
extra_compile_args=['-std=c++17']
))
else:
extensions.append(PyCudaExtension(
extensions.append(PyCUDAExtension(
name=osp.join(ops_dir, f'_{ops_prefix}_cuda').replace(osp.sep, "."),
sources=[osp.join(src_dir, f) for f in src_files],
include_dirs=[osp.abspath(osp.join('third_party', d)) for d in ops_third_party_deps[i]],
Expand Down Expand Up @@ -136,7 +139,7 @@ def readme():
author_email="[email protected]",
maintainer="Tianyu Zhao",
license="Apache-2.0 License",
cmdclass={'build_ext': BuildExtension},
cmdclass={'build_ext': PyBuildExtension},
ext_modules=load_extensions(),
description=" ",
long_description=readme(),
Expand Down

0 comments on commit d26014a

Please sign in to comment.