Skip to content

Commit 6e4c6a8

Browse files
authored
feat: build the shardy dialect (#622)
* feat: build the shardy dialect * fix: remove http_archive * fix: register dialect * fix: missing loadDialect
1 parent 29b0eac commit 6e4c6a8

File tree

3 files changed

+29
-1
lines changed

3 files changed

+29
-1
lines changed

deps/ReactantExtra/API.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
#include "src/enzyme_ad/jax/Implementations/XLADerivatives.h"
3838
#include "src/enzyme_ad/jax/Passes/Passes.h"
3939
#include "llvm/Support/TargetSelect.h"
40+
#include "shardy/dialect/sdy/ir/dialect.h"
4041

4142
#include "mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h"
4243
#include "stablehlo/dialect/ChloOps.h"
@@ -685,6 +686,7 @@ extern "C" void RegisterDialects(MlirContext cctx) {
685686
context.loadDialect<mlir::mhlo::MhloDialect>();
686687
context.loadDialect<mlir::stablehlo::StablehloDialect>();
687688
context.loadDialect<mlir::chlo::ChloDialect>();
689+
context.loadDialect<mlir::sdy::SdyDialect>();
688690
}
689691

690692
#include "mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h"
@@ -713,6 +715,8 @@ extern "C" void InitializeRegistryAndPasses(MlirDialectRegistry creg) {
713715
mlir::registerNVVMDialectImport(registry);
714716
mlir::LLVM::registerInlinerInterface(registry);
715717

718+
mlir::sdy::registerAllDialects(registry);
719+
716720
/*
717721
registry.addExtension(+[](MLIRContext *ctx, LLVM::LLVMDialect *dialect) {
718722
LLVM::LLVMFunctionType::attachInterface<MemRefInsider>(*ctx);

deps/ReactantExtra/BUILD

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -795,6 +795,21 @@ gentbl_cc_library(
795795
tblgen = "//:mlir-jl-tblgen",
796796
)
797797

798+
gentbl_cc_library(
799+
name = "ShardyJLIncGen",
800+
tbl_outs = [(
801+
["--generator=jl-op-defs", "--disable-module-wrap=0"],
802+
"Shardy.jl"
803+
)
804+
],
805+
td_file = "@shardy//shardy/dialect/sdy/ir:ops.td",
806+
deps = [
807+
"@shardy//shardy/dialect/sdy/ir:sdy_td_files",
808+
],
809+
tblgen = "//:mlir-jl-tblgen",
810+
includes = ["external/shardy"],
811+
)
812+
798813
genrule(
799814
name = "libMLIR_h.jl",
800815
tags = [

deps/ReactantExtra/make-bindings.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,16 @@
1+
const bazel_cmd = if !isnothing(Sys.which("bazelisk"))
2+
"bazelisk"
3+
elseif !isnothing(Sys.which("bazel"))
4+
"bazel"
5+
else
6+
error("Could not find `bazel` or `bazelisk` in PATH!")
7+
end
8+
19
function build_file(output_path)
210
file = basename(output_path)
311
run(
412
Cmd(
5-
`bazel build --action_env=JULIA=$(Base.julia_cmd().exec[1]) --action_env=JULIA_DEPOT_PATH=$(Base.DEPOT_PATH) --repo_env HERMETIC_PYTHON_VERSION="3.10" --check_visibility=false --verbose_failures //:$file`;
13+
`$(bazel_cmd) build --action_env=JULIA=$(Base.julia_cmd().exec[1]) --action_env=JULIA_DEPOT_PATH=$(Base.DEPOT_PATH) --repo_env HERMETIC_PYTHON_VERSION="3.10" --check_visibility=false --verbose_failures //:$file`;
614
dir=@__DIR__,
715
),
816
)
@@ -29,6 +37,7 @@ for file in [
2937
"Affine.jl",
3038
"TPU.jl",
3139
"Triton.jl",
40+
"Shardy.jl",
3241
]
3342
build_file(joinpath(src_dir, "mlir", "Dialects", file))
3443
end

0 commit comments

Comments
 (0)