diff --git a/src/enzyme_ad/jax/BUILD b/src/enzyme_ad/jax/BUILD index 4379f2e5f..6f32d78e8 100644 --- a/src/enzyme_ad/jax/BUILD +++ b/src/enzyme_ad/jax/BUILD @@ -112,6 +112,7 @@ pybind_extension( "@llvm-project//llvm:OrcJIT", "@llvm-project//llvm:OrcTargetProcess", "@llvm-project//llvm:Support", + "@llvm-project//mlir:AllPassesAndDialects", ":clang_compile", ":compile_with_xla", "@com_google_absl//absl/status:statusor", diff --git a/src/enzyme_ad/jax/compile_with_xla.cc b/src/enzyme_ad/jax/compile_with_xla.cc index eef2c8baa..224ef5314 100644 --- a/src/enzyme_ad/jax/compile_with_xla.cc +++ b/src/enzyme_ad/jax/compile_with_xla.cc @@ -33,7 +33,8 @@ // Compile an MHLO module given as a string to LLVM IR using XLA. std::unique_ptr compile_mhlo_to_llvm_with_xla(llvm::StringRef mhlo_text, std::string &output, - bool xla_runtime) { + bool xla_runtime, + const std::string &pass_pipeline) { // Parse MLIR. mlir::MLIRContext context; context.loadDialect(); @@ -103,6 +104,10 @@ compile_mhlo_to_llvm_with_xla(llvm::StringRef mhlo_text, std::string &output, build_options.mutable_debug_options()->set_xla_cpu_use_xla_runtime( xla_runtime); + build_options.mutable_debug_options() + ->mutable_xla_backend_extra_options() + ->emplace("xla_cpu_experimental_override_pipeline", pass_pipeline); + if (build_options.device_ordinal() == -1) { build_options.set_device_ordinal(local_client->default_device_ordinal()); } diff --git a/src/enzyme_ad/jax/compile_with_xla.h b/src/enzyme_ad/jax/compile_with_xla.h index 106ed918d..a2e8f3ac0 100644 --- a/src/enzyme_ad/jax/compile_with_xla.h +++ b/src/enzyme_ad/jax/compile_with_xla.h @@ -5,4 +5,5 @@ // Compile an MHLO module given as a string to LLVM IR using XLA. std::unique_ptr compile_mhlo_to_llvm_with_xla(llvm::StringRef mhlo_text, std::string &output, - bool xla_runtime); + bool xla_runtime, + const std::string &pass_pipeline); diff --git a/src/enzyme_ad/jax/enzyme_call.cc b/src/enzyme_ad/jax/enzyme_call.cc index 07a09dacb..e25c9b615 100644 --- a/src/enzyme_ad/jax/enzyme_call.cc +++ b/src/enzyme_ad/jax/enzyme_call.cc @@ -38,6 +38,8 @@ #include "llvm/Support/TargetSelect.h" #include "llvm/Support/raw_ostream.h" +#include "mlir/InitAllPasses.h" + #include "compile_with_xla.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instructions.h" @@ -81,7 +83,8 @@ class CpuKernel { llvm::ArrayRef out_names, llvm::ArrayRef> in_shapes, llvm::ArrayRef in_names, PyObject *pyargv, - ABI mode, Language lang, bool xla_runtime) { + ABI mode, Language lang, bool xla_runtime, + const std::string &pass_pipeline) { auto llvm_ctx = std::make_unique(); std::string input; @@ -102,8 +105,8 @@ class CpuKernel { break; case Language::MHLO: { - local_executable = - compile_mhlo_to_llvm_with_xla(source, stringbuf, xla_runtime); + local_executable = compile_mhlo_to_llvm_with_xla( + source, stringbuf, xla_runtime, pass_pipeline); auto *cpu_executable = static_cast( local_executable->executable()); auto &assignment = cpu_executable->buffer_assignment(); @@ -830,11 +833,12 @@ class CpuKernel { llvm::ArrayRef out_names, llvm::ArrayRef> in_shapes, llvm::ArrayRef in_names, PyObject *pyargv, - Language lang, bool xla_runtime) { + Language lang, bool xla_runtime, + const std::string &pass_pipeline) { auto mode = ABI::Tape; auto [mod, llvm_ctx, num_out, tmpBuf] = createLLVMMod(fn, source, out_shapes, out_names, in_shapes, in_names, - pyargv, mode, lang, xla_runtime); + pyargv, mode, lang, xla_runtime, pass_pipeline); auto lfn = mod->getFunction("entry"); auto RI = llvm::cast(lfn->getEntryBlock().getTerminator()); @@ -846,12 +850,12 @@ class CpuKernel { } static size_t tempSize(llvm::StringRef source, Language lang, - bool xla_runtime) { + bool xla_runtime, const std::string &pass_pipeline) { switch (lang) { case Language::MHLO: { std::string llvm_ir; - auto local_executable = - compile_mhlo_to_llvm_with_xla(source, llvm_ir, xla_runtime); + auto local_executable = compile_mhlo_to_llvm_with_xla( + source, llvm_ir, xla_runtime, pass_pipeline); auto *cpu_executable = static_cast( local_executable->executable()); auto &assignment = cpu_executable->buffer_assignment(); @@ -868,13 +872,13 @@ class CpuKernel { llvm::ArrayRef out_names, llvm::ArrayRef> in_shapes, llvm::ArrayRef in_names, PyObject *pyargv, ABI mode, - Language lang, bool xla_runtime) { + Language lang, bool xla_runtime, const std::string &pass_pipeline) { llvm::sys::SmartScopedWriter lock(kernel_mutex); size_t identifier = last_identifier++; auto [mod, llvm_ctx, num_out, tmpBuf] = createLLVMMod(fn, source, out_shapes, out_names, in_shapes, in_names, - pyargv, mode, lang, xla_runtime); + pyargv, mode, lang, xla_runtime, pass_pipeline); if (!JIT) { DL = std::make_unique(mod.get()); @@ -986,6 +990,8 @@ PYBIND11_MODULE(enzyme_call, m) { llvm::InitializeAllAsmParsers(); EnzymeAlwaysInlineDiff.setValue(true); + mlir::registerAllPasses(); + pybind11::enum_(m, "Language") .value("CPP", Language::CPP) .value("LLVM", Language::LLVM) @@ -1002,8 +1008,8 @@ PYBIND11_MODULE(enzyme_call, m) { [](const std::string &source, const std::string &fn, const pybind11::list &py_out_shapes, const pybind11::list &py_in_shapes, pybind11::object pyargv, - ABI mode, Language lang, - bool xla_runtime) -> std::tuple { + ABI mode, Language lang, bool xla_runtime, + const std::string &pass_pipeline) -> std::tuple { llvm::SmallVector> out_shapes; out_shapes.reserve(pybind11::len(py_out_shapes)); llvm::SmallVector> in_shapes; @@ -1039,20 +1045,22 @@ PYBIND11_MODULE(enzyme_call, m) { } return CpuKernel::create(fn, source, out_shapes, out_types, in_shapes, in_types, pyargv.ptr(), mode, (Language)lang, - xla_runtime); + xla_runtime, pass_pipeline); }); - m.def( - "tmp_size", - [](const std::string &source, Language lang, bool xla_runtime) -> size_t { - return CpuKernel::tempSize(source, (Language)lang, xla_runtime); - }); + m.def("tmp_size", + [](const std::string &source, Language lang, bool xla_runtime, + const std::string &pass_pipeline) -> size_t { + return CpuKernel::tempSize(source, (Language)lang, xla_runtime, + pass_pipeline); + }); m.def("tape_and_tmp_size", [](const std::string &source, const std::string &fn, const pybind11::list &py_out_shapes, const pybind11::list &py_in_shapes, pybind11::object pyargv, - Language lang, bool xla_runtime) -> std::pair { + Language lang, bool xla_runtime, + const std::string &pass_pipeline) -> std::pair { llvm::SmallVector> out_shapes; out_shapes.reserve(pybind11::len(py_out_shapes)); llvm::SmallVector> in_shapes; @@ -1086,9 +1094,9 @@ PYBIND11_MODULE(enzyme_call, m) { target.push_back(nested_element.cast()); } } - return CpuKernel::tapeAndTempSize(fn, source, out_shapes, out_types, - in_shapes, in_types, pyargv.ptr(), - (Language)lang, xla_runtime); + return CpuKernel::tapeAndTempSize( + fn, source, out_shapes, out_types, in_shapes, in_types, + pyargv.ptr(), (Language)lang, xla_runtime, pass_pipeline); }); m.def("get_cpu_callback", []() { @@ -1097,9 +1105,11 @@ PYBIND11_MODULE(enzyme_call, m) { }); m.def("compile_mhlo_to_llvm_with_xla", - [](const std::string &mhlo_text, bool xla_runtime) { + [](const std::string &mhlo_text, bool xla_runtime, + const std::string &pass_pipeline) { std::string llvm_ir; - compile_mhlo_to_llvm_with_xla(mhlo_text, llvm_ir, xla_runtime); + compile_mhlo_to_llvm_with_xla(mhlo_text, llvm_ir, xla_runtime, + pass_pipeline); return llvm_ir; }); } diff --git a/src/enzyme_ad/jax/primitives.py b/src/enzyme_ad/jax/primitives.py index 416a1be89..7c94ba7c6 100644 --- a/src/enzyme_ad/jax/primitives.py +++ b/src/enzyme_ad/jax/primitives.py @@ -22,8 +22,11 @@ LANG_LLVM = enzyme_call.Language.LLVM LANG_MHLO = enzyme_call.Language.MHLO -xla_runtime = True +def xla_runtime(options): + return True +def pass_pipeline(options): + return "any(inline{default-pipeline=canonicalize max-iterations=4 },expand-hlo-tuples{entry-function=main},func.func(mhlo-flatten-tuple),xla-legalize-abi,func.func(mhlo-test-lower-general-dot),func.func(mhlo-broadcast-propagation),cse,canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=true test-convergence=false top-down=true},func.func(xla-sparse-custom-call-to-pack),func.func(legalize-sparse-ops{legalize-to-custom-calls=false}),func.func(chlo-legalize-to-hlo{expand-compositions=true legalize-broadcasts=true}),func.func(mhlo-sparse-rewriting),func.func(mhlo-legalize-control-flow),func.func(mhlo-legalize-dot-general-to-dot),hlo-legalize-to-arithmetic,func.func(xla-legalize-library-ops),func.func(mhlo-expand-ops-simplifier),func.func(hlo-canonicalize-scatter),func.func(hlo-canonicalize-dot),func.func(group-reduction-dimensions{prefer-columns-reductions=true}),func.func(hlo-legalize-to-linalg{enable-primitive-ops=false}),func.func(lower-index-cast),convert-to-signless,func.func(shape-simplification),func.func(shape-to-shape-lowering),convert-shape-to-std,func.func(convert-shape-constraints),cse,resolve-shaped-type-result-dims,canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=true test-convergence=false top-down=true},func.func(linalg-fuse-elementwise-ops),reconcile-unrealized-casts,convert-tensor-to-linalg,func.func(detensorize-scf-ops),func.func(linalg-detensorize{aggressive-mode=true}),eliminate-empty-tensors,func.func(empty-tensor-to-alloc-tensor),canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=true test-convergence=false top-down=true},func.func(linalg-generalize-named-ops),eliminate-empty-tensors,sparsification-and-bufferization,sparse-storage-specifier-to-llvm,func.func(canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=true test-convergence=false top-down=true}),func.func(finalizing-bufferize),func.func(xla-rewrite-realloc-to-alloc),func.func(vectorize-copy),func.func(naive-copy-removal),func.func(convert-linalg-to-loops),cse,canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=true test-convergence=false top-down=true},buffer-results-to-out-params,func.func(promote-buffers-to-stack{max-alloc-size-in-bytes=1024 max-rank-of-allocated-memref=1}),func.func(buffer-deallocation),convert-bufferization-to-memref,func.func(xla-remove-copies-to-out-params),canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=true test-convergence=false top-down=true},func.func(convert-complex-to-standard),cse,canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=true test-convergence=false top-down=true},func.func(convert-vector-to-scf{full-unroll=false lower-tensors=false target-rank=1}),func.func(xla-legalize-i1-vector-transfers),func.func(xla-convert-memref-element-cast-to-llvm),async-func-to-async-runtime,xla-rt-export-functions,xla-cpu-to-cpu-runtime,xla-rt-convert-custom-calls,xla-rt-convert-asserts,inline{default-pipeline=canonicalize max-iterations=4 },canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=true test-convergence=false top-down=true},cse,func.func(xla-math-approximation{oplist=all}),func.func(convert-linalg-to-parallel-loops),canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=true test-convergence=false top-down=true},async-to-async-runtime,xla-rt-move-allocas-to-entry-block,async-runtime-policy-based-ref-counting,func.func(arith-expand{include-bf16=false}),func.func(memref-expand),func.func(expand-strided-metadata),lower-affine,func.func(xla-memref-aligned-allocations{alignment=0}),xla-rt-to-llvm,convert-async-to-llvm,generic-host-to-llvm{enable-avx2=false},reconcile-unrealized-casts,canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=true test-convergence=false top-down=true},cse)" def resource_dir(): import os @@ -69,6 +72,7 @@ def _enzyme_primal_impl( argv: Sequence[str], out_shapes: Sequence[jax.core.ShapedArray], lang: enzyme_call.Language, + pipeline_options ) -> Sequence[jax.Array]: del args_flat, source, out_shapes raise RuntimeError("must be JIT'ed") @@ -81,6 +85,7 @@ def _enzyme_fwd_impl( argv: Sequence[str], out_shapes: Sequence[jax.core.ShapedArray], lang: enzyme_call.Language, + pipeline_options ) -> Sequence[jax.Array]: del args_flat, source, out_shapes raise RuntimeError("must be JIT'ed") @@ -93,6 +98,7 @@ def _enzyme_aug_impl( argv: Sequence[str], out_shapes: Sequence[jax.core.ShapedArray], lang: enzyme_call.Language, + pipeline_options ) -> Sequence[jax.Array]: del args_flat, source, out_shapes raise RuntimeError("must be JIT'ed") @@ -105,6 +111,7 @@ def _enzyme_shadow_aug_impl( argv: Sequence[str], out_shapes: Sequence[jax.core.ShapedArray], lang: enzyme_call.Language, + pipeline_options ) -> Sequence[jax.Array]: del args_flat, source, out_shapes raise RuntimeError("must be JIT'ed") @@ -117,6 +124,7 @@ def _enzyme_rev_impl( argv: Sequence[str], in_shapes, lang: enzyme_call.Language, + pipeline_options ) -> Sequence[jax.Array]: del args_flat, source, out_shapes raise RuntimeError("must be JIT'ed") @@ -129,6 +137,7 @@ def _enzyme_primal_abstract_eval( argv: Sequence[str], out_shapes: Sequence[jax.core.ShapedArray], lang: enzyme_call.Language, + pipeline_options ) -> Sequence[jax.core.ShapedArray]: # TODO: we may attempt some lightweight parsing of source to extract the # result types instead. @@ -142,6 +151,7 @@ def _enzyme_fwd_abstract_eval( argv: Sequence[str], out_shapes: Sequence[jax.core.ShapedArray], lang: enzyme_call.Language, + pipeline_options ) -> Sequence[jax.core.ShapedArray]: del source, fn, args_flat return tuple(o for o in out_shapes for _ in range(2)) @@ -160,6 +170,7 @@ def _enzyme_aug_abstract_eval( argv: Sequence[str], out_shapes: Sequence[jax.core.ShapedArray], lang: enzyme_call.Language, + pipeline_options ) -> Sequence[jax.core.ShapedArray]: in_shapes = args_flat @@ -181,7 +192,7 @@ def _enzyme_aug_abstract_eval( argv = argv + ("-resource-dir", resource_dir()) + cflags() tapeSize, tmpSize = enzyme_call.tape_and_tmp_size( - source, fn, out_shapes, in_shapes, argv, lang, xla_runtime + source, fn, out_shapes, in_shapes, argv, lang, xla_runtime(pipeline_options), pass_pipeline(pipeline_options) ) res = tuple(prev_out_shapes) + ( jax.core.ShapedArray((tapeSize,), (jax.numpy.int8)), @@ -196,6 +207,7 @@ def _enzyme_shadow_aug_abstract_eval( argv: Sequence[str], out_shapes: Sequence[jax.core.ShapedArray], lang: enzyme_call.Language, + pipeline_options ) -> Sequence[jax.core.ShapedArray]: return out_shapes @@ -207,6 +219,7 @@ def _enzyme_rev_abstract_eval( argv: Sequence[str], in_shapes, lang: enzyme_call.Language, + pipeline_options ) -> Sequence[jax.core.ShapedArray]: return tuple( jax.core.ShapedArray(shape, dejaxify(tyid)) for (shape, tyid) in in_shapes @@ -233,6 +246,7 @@ def _enzyme_primal_lowering( argv: Sequence[str], out_shapes: Sequence[jax.core.ShapedArray], lang: enzyme_call.Language, + pipeline_options ) -> Sequence[ir.Value]: del out_shapes @@ -262,7 +276,8 @@ def _enzyme_primal_lowering( argv, enzyme_call.ABI.Primal, lang, - xla_runtime, + xla_runtime(pipeline_options), + pass_pipeline(pipeline_options) ) identifier_attr = jax_mlir.dense_int_elements([identifier]) identifier_op = stablehlo.ConstantOp(identifier_attr) @@ -293,6 +308,7 @@ def _enzyme_fwd_lowering( argv: Sequence[str], out_shapes: Sequence[jax.core.ShapedArray], lang: enzyme_call.Language, + pipeline_options ) -> Sequence[ir.Value]: del out_shapes @@ -323,7 +339,8 @@ def _enzyme_fwd_lowering( argv, enzyme_call.ABI.Forward, lang, - xla_runtime, + xla_runtime(pipeline_options), + pass_pipeline(pipeline_options) ) identifier_attr = jax_mlir.dense_int_elements([identifier]) identifier_op = stablehlo.ConstantOp(identifier_attr) @@ -353,6 +370,7 @@ def _enzyme_aug_lowering( argv: Sequence[str], out_shapes: Sequence[jax.core.ShapedArray], lang: enzyme_call.Language, + pipeline_options ) -> Sequence[ir.Value]: del out_shapes @@ -383,7 +401,8 @@ def _enzyme_aug_lowering( argv, enzyme_call.ABI.Augmented, lang, - xla_runtime, + xla_runtime(pipeline_options), + pass_pipeline(pipeline_options) ) identifier_attr = jax_mlir.dense_int_elements([identifier]) identifier_op = stablehlo.ConstantOp(identifier_attr) @@ -411,6 +430,7 @@ def _enzyme_rev_lowering( argv: Sequence[str], in_shapes: Sequence[jax.core.ShapedArray], lang: enzyme_call.Language, + pipeline_options ) -> Sequence[ir.Value]: del in_shapes @@ -450,7 +470,8 @@ def _enzyme_rev_lowering( argv, enzyme_call.ABI.Reverse, lang, - xla_runtime, + xla_runtime(pipeline_options), + pass_pipeline(pipeline_options) ) identifier_attr = jax_mlir.dense_int_elements([identifier]) identifier_op = stablehlo.ConstantOp(identifier_attr) @@ -497,9 +518,10 @@ def ffi_call( fn: str = "f", argv: tuple[str] = (), lang: int = LANG_CPP, + pipeline_options = None ): return _enzyme_primal_p.bind( - *args, source=source, fn=fn, argv=argv, out_shapes=out_shapes, lang=lang + *args, source=source, fn=fn, argv=argv, out_shapes=out_shapes, lang=lang, pipeline_options=pipeline_options ) @@ -509,9 +531,10 @@ def cpp_call( source: str, fn: str = "f", argv: tuple[str] = (), + pipeline_options = None ): return ffi_call( - *args, source=source, fn=fn, argv=argv, out_shapes=out_shapes, lang=LANG_CPP + *args, source=source, fn=fn, argv=argv, out_shapes=out_shapes, lang=LANG_CPP, pipeline_options=pipeline_options ) @@ -550,6 +573,7 @@ def make_zero(tan, prim): argv=kwargs["argv"], out_shapes=kwargs["out_shapes"], lang=kwargs["lang"], + pipeline_options=kwargs["pipeline_options"] ) res = (shadconv[0::2], shadconv[1::2]) return res @@ -640,7 +664,7 @@ def enzyme_vjp(shadow_rets, *prim_args, **kwargs): ad.primitive_transposes[_enzyme_shadow_aug_p] = enzyme_vjp -def enzyme_jax_ir(argv=()): +def enzyme_jax_ir(argv=(), pipeline_options=None): def decorator(func: Callable[..., Any]) -> Callable[..., Any]: @jax.jit def wrapped(*args: Any): @@ -657,6 +681,7 @@ def wrapped(*args: Any): out_shapes=out_shape_flat, argv=argv, lang=LANG_MHLO, + pipeline_options=pipeline_options ) return jax.tree_util.tree_unflatten(out_tree, out_flat)