diff --git a/.gitignore b/.gitignore index 3035a805b12..466d73f7465 100644 --- a/.gitignore +++ b/.gitignore @@ -31,3 +31,8 @@ bazel-* # Clangd cache directory .cache/* + + +# DISC outputs +disc_compiler_main +torch-mlir-opt diff --git a/bazel/disc.BUILD b/bazel/disc.BUILD index 7ed33604618..1971f399dc7 100644 --- a/bazel/disc.BUILD +++ b/bazel/disc.BUILD @@ -35,9 +35,10 @@ cc_import( genrule( name = "build_disc", + srcs = glob(["third_party/BladeDISC/**"]), outs = ["libral_base_context.so", "libdisc_custom_ops.so", "disc_compiler_main", "torch-mlir-opt"], local = True, - cmd = ';'.join(['export PATH=/root/bin:/usr/local/cuda/bin:$${PATH}', + cmd = '&&'.join(['export PATH=/root/bin:/usr/local/cuda/bin:$${PATH}', 'pushd external/disc_compiler/pytorch_blade/', 'python ../scripts/python/common_setup.py', 'TF_CUDA_COMPUTE_CAPABILITIES="7.0,8.0,8.6,9.0" TORCH_CUDA_ARCH_LIST="7.0 8.0 8.6 9.0" python setup.py bdist_wheel', @@ -46,4 +47,5 @@ genrule( 'cp third_party/BladeDISC/pytorch_blade/bazel-bin/external/org_disc_compiler/mlir/custom_ops/libdisc_custom_ops.so $(location libdisc_custom_ops.so)', 'cp third_party/BladeDISC/pytorch_blade/bazel-bin/external/org_disc_compiler/mlir/disc/disc_compiler_main $(location disc_compiler_main)', 'cp third_party/BladeDISC/pytorch_blade/bazel-bin/tests/mhlo/torch-mlir-opt/torch-mlir-opt $(location torch-mlir-opt)']), + visibility = ["//visibility:public"], ) diff --git a/bazel/flash_attn.BUILD b/bazel/flash_attn.BUILD index 6be811b826b..c51a563e7fb 100644 --- a/bazel/flash_attn.BUILD +++ b/bazel/flash_attn.BUILD @@ -21,10 +21,11 @@ cc_import( genrule( name = "build_flash_attn", - srcs = ["setup.py"], + srcs = glob(["third_party/flash-attention/**"]), outs = ["flash_attn_cuda.so"], - cmd = ';'.join(['pushd external/flash_attn/', + cmd = '&&'.join(['pushd external/flash_attn/', 'FLASH_ATTENTION_FORCE_BUILD=TRUE python setup.py bdist_wheel', 'popd', - 'cp external/flash_attn/build/*/*.so $(location flash_attn_cuda.so)']), + 'cp external/flash_attn/build/*/*.so $(OUTS)']), + visibility = ["//visibility:public"], ) diff --git a/setup.py b/setup.py index 3531a61e211..aad8b3c1205 100644 --- a/setup.py +++ b/setup.py @@ -52,6 +52,7 @@ import contextlib import distutils.ccompiler import distutils.command.clean +import glob import os import requests import shutil @@ -245,26 +246,30 @@ def bazel_build(self, ext): os.system(f"patchelf --add-rpath '$ORIGIN/' {ext_bazel_bin_path}") shutil.copyfile(ext_bazel_bin_path, ext_dest_path) + def copyfiles(bazel_bin_path, file_name_list): + if not isinstance(file_name_list, list): + file_name_list = [file_name_list] + + for file_name in file_name_list: + src = glob.glob(os.path.join(bazel_bin_path, file_name)) + assert len(src) == 1 + dest = '/'.join([ext_dest_dir, file_name]) + shutil.copy2(src[0], dest) + # copy flash attention cuda so file - flash_attn_so_name = 'flash_attn_cuda.so' - bazel_bin_path = 'build/temp.linux-x86_64-cpython-310/bazel-bin/external/flash_attn/' - shutil.copyfile('/'.join([bazel_bin_path, flash_attn_so_name]), - '/'.join([ext_dest_dir, flash_attn_so_name])) + copyfiles( + 'build/*/bazel-bin/external/flash_attn/', + ['flash_attn_cuda.so']) # package BladeDISC distribution files # please note, TorchBlade also create some symbolic links to 'torch_blade' dir if build_util.check_env_flag('ENABLE_DISC', 'false'): - disc_ral_so_name = 'libral_base_context.so' - bazel_bin_path = 'build/temp.linux-x86_64-cpython-310/bazel-bin/external/disc_compiler' - shutil.copyfile( - os.path.join(bazel_bin_path, disc_ral_so_name), - '/'.join([ext_dest_dir, disc_ral_so_name])) - - disc_customop_so_name = 'libdisc_custom_ops.so' - bazel_bin_path = 'build/temp.linux-x86_64-cpython-310/bazel-bin/external/disc_compiler' - shutil.copyfile( - os.path.join(bazel_bin_path, disc_customop_so_name), - '/'.join([ext_dest_dir, disc_customop_so_name])) + copyfiles( + 'build/*/bazel-bin/external/disc_compiler', + [ + 'libral_base_context.so', 'libdisc_custom_ops.so', + 'disc_compiler_main', 'torch-mlir-opt' + ]) class Develop(develop.develop):