Skip to content

Commit

Permalink
Update DISC and FA bazel building scripts
Browse files Browse the repository at this point in the history
  • Loading branch information
anw90 committed Oct 22, 2024
1 parent fab18e0 commit 8037f43
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 19 deletions.
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,8 @@ bazel-*

# Clangd cache directory
.cache/*


# DISC outputs
disc_compiler_main
torch-mlir-opt
4 changes: 3 additions & 1 deletion bazel/disc.BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,10 @@ cc_import(

genrule(
name = "build_disc",
srcs = glob(["third_party/BladeDISC/**"]),
outs = ["libral_base_context.so", "libdisc_custom_ops.so", "disc_compiler_main", "torch-mlir-opt"],
local = True,
cmd = ';'.join(['export PATH=/root/bin:/usr/local/cuda/bin:$${PATH}',
cmd = '&&'.join(['export PATH=/root/bin:/usr/local/cuda/bin:$${PATH}',
'pushd external/disc_compiler/pytorch_blade/',
'python ../scripts/python/common_setup.py',
'TF_CUDA_COMPUTE_CAPABILITIES="7.0,8.0,8.6,9.0" TORCH_CUDA_ARCH_LIST="7.0 8.0 8.6 9.0" python setup.py bdist_wheel',
Expand All @@ -46,4 +47,5 @@ genrule(
'cp third_party/BladeDISC/pytorch_blade/bazel-bin/external/org_disc_compiler/mlir/custom_ops/libdisc_custom_ops.so $(location libdisc_custom_ops.so)',
'cp third_party/BladeDISC/pytorch_blade/bazel-bin/external/org_disc_compiler/mlir/disc/disc_compiler_main $(location disc_compiler_main)',
'cp third_party/BladeDISC/pytorch_blade/bazel-bin/tests/mhlo/torch-mlir-opt/torch-mlir-opt $(location torch-mlir-opt)']),
visibility = ["//visibility:public"],
)
7 changes: 4 additions & 3 deletions bazel/flash_attn.BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,11 @@ cc_import(

genrule(
name = "build_flash_attn",
srcs = ["setup.py"],
srcs = glob(["third_party/flash-attention/**"]),
outs = ["flash_attn_cuda.so"],
cmd = ';'.join(['pushd external/flash_attn/',
cmd = '&&'.join(['pushd external/flash_attn/',
'FLASH_ATTENTION_FORCE_BUILD=TRUE python setup.py bdist_wheel',
'popd',
'cp external/flash_attn/build/*/*.so $(location flash_attn_cuda.so)']),
'cp external/flash_attn/build/*/*.so $(OUTS)']),
visibility = ["//visibility:public"],
)
35 changes: 20 additions & 15 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
import contextlib
import distutils.ccompiler
import distutils.command.clean
import glob
import os
import requests
import shutil
Expand Down Expand Up @@ -245,26 +246,30 @@ def bazel_build(self, ext):
os.system(f"patchelf --add-rpath '$ORIGIN/' {ext_bazel_bin_path}")
shutil.copyfile(ext_bazel_bin_path, ext_dest_path)

def copyfiles(bazel_bin_path, file_name_list):
if not isinstance(file_name_list, list):
file_name_list = [file_name_list]

for file_name in file_name_list:
src = glob.glob(os.path.join(bazel_bin_path, file_name))
assert len(src) == 1
dest = '/'.join([ext_dest_dir, file_name])
shutil.copy2(src[0], dest)

# copy flash attention cuda so file
flash_attn_so_name = 'flash_attn_cuda.so'
bazel_bin_path = 'build/temp.linux-x86_64-cpython-310/bazel-bin/external/flash_attn/'
shutil.copyfile('/'.join([bazel_bin_path, flash_attn_so_name]),
'/'.join([ext_dest_dir, flash_attn_so_name]))
copyfiles(
'build/*/bazel-bin/external/flash_attn/',
['flash_attn_cuda.so'])

# package BladeDISC distribution files
# please note, TorchBlade also create some symbolic links to 'torch_blade' dir
if build_util.check_env_flag('ENABLE_DISC', 'false'):
disc_ral_so_name = 'libral_base_context.so'
bazel_bin_path = 'build/temp.linux-x86_64-cpython-310/bazel-bin/external/disc_compiler'
shutil.copyfile(
os.path.join(bazel_bin_path, disc_ral_so_name),
'/'.join([ext_dest_dir, disc_ral_so_name]))

disc_customop_so_name = 'libdisc_custom_ops.so'
bazel_bin_path = 'build/temp.linux-x86_64-cpython-310/bazel-bin/external/disc_compiler'
shutil.copyfile(
os.path.join(bazel_bin_path, disc_customop_so_name),
'/'.join([ext_dest_dir, disc_customop_so_name]))
copyfiles(
'build/*/bazel-bin/external/disc_compiler',
[
'libral_base_context.so', 'libdisc_custom_ops.so',
'disc_compiler_main', 'torch-mlir-opt'
])


class Develop(develop.develop):
Expand Down

0 comments on commit 8037f43

Please sign in to comment.