File tree 3 files changed +29
-1
lines changed
3 files changed +29
-1
lines changed Original file line number Diff line number Diff line change 37
37
#include " src/enzyme_ad/jax/Implementations/XLADerivatives.h"
38
38
#include " src/enzyme_ad/jax/Passes/Passes.h"
39
39
#include " llvm/Support/TargetSelect.h"
40
+ #include " shardy/dialect/sdy/ir/dialect.h"
40
41
41
42
#include " mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h"
42
43
#include " stablehlo/dialect/ChloOps.h"
@@ -685,6 +686,7 @@ extern "C" void RegisterDialects(MlirContext cctx) {
685
686
context.loadDialect <mlir::mhlo::MhloDialect>();
686
687
context.loadDialect <mlir::stablehlo::StablehloDialect>();
687
688
context.loadDialect <mlir::chlo::ChloDialect>();
689
+ context.loadDialect <mlir::sdy::SdyDialect>();
688
690
}
689
691
690
692
#include " mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h"
@@ -713,6 +715,8 @@ extern "C" void InitializeRegistryAndPasses(MlirDialectRegistry creg) {
713
715
mlir::registerNVVMDialectImport (registry);
714
716
mlir::LLVM::registerInlinerInterface (registry);
715
717
718
+ mlir::sdy::registerAllDialects (registry);
719
+
716
720
/*
717
721
registry.addExtension(+[](MLIRContext *ctx, LLVM::LLVMDialect *dialect) {
718
722
LLVM::LLVMFunctionType::attachInterface<MemRefInsider>(*ctx);
Original file line number Diff line number Diff line change @@ -795,6 +795,21 @@ gentbl_cc_library(
795
795
tblgen = "//:mlir-jl-tblgen" ,
796
796
)
797
797
798
+ gentbl_cc_library (
799
+ name = "ShardyJLIncGen" ,
800
+ tbl_outs = [(
801
+ ["--generator=jl-op-defs" , "--disable-module-wrap=0" ],
802
+ "Shardy.jl"
803
+ )
804
+ ],
805
+ td_file = "@shardy//shardy/dialect/sdy/ir:ops.td" ,
806
+ deps = [
807
+ "@shardy//shardy/dialect/sdy/ir:sdy_td_files" ,
808
+ ],
809
+ tblgen = "//:mlir-jl-tblgen" ,
810
+ includes = ["external/shardy" ],
811
+ )
812
+
798
813
genrule (
799
814
name = "libMLIR_h.jl" ,
800
815
tags = [
Original file line number Diff line number Diff line change
1
+ const bazel_cmd = if ! isnothing (Sys. which (" bazelisk" ))
2
+ " bazelisk"
3
+ elseif ! isnothing (Sys. which (" bazel" ))
4
+ " bazel"
5
+ else
6
+ error (" Could not find `bazel` or `bazelisk` in PATH!" )
7
+ end
8
+
1
9
function build_file (output_path)
2
10
file = basename (output_path)
3
11
run (
4
12
Cmd (
5
- ` bazel build --action_env=JULIA=$(Base. julia_cmd (). exec[1 ]) --action_env=JULIA_DEPOT_PATH=$(Base. DEPOT_PATH ) --repo_env HERMETIC_PYTHON_VERSION="3.10" --check_visibility=false --verbose_failures //:$file ` ;
13
+ ` $(bazel_cmd) build --action_env=JULIA=$(Base. julia_cmd (). exec[1 ]) --action_env=JULIA_DEPOT_PATH=$(Base. DEPOT_PATH ) --repo_env HERMETIC_PYTHON_VERSION="3.10" --check_visibility=false --verbose_failures //:$file ` ;
6
14
dir= @__DIR__ ,
7
15
),
8
16
)
@@ -29,6 +37,7 @@ for file in [
29
37
" Affine.jl" ,
30
38
" TPU.jl" ,
31
39
" Triton.jl" ,
40
+ " Shardy.jl" ,
32
41
]
33
42
build_file (joinpath (src_dir, " mlir" , " Dialects" , file))
34
43
end
You can’t perform that action at this time.
0 commit comments