diff --git a/WORKSPACE b/WORKSPACE index 08c33578b..4d6878ea6 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -43,7 +43,7 @@ XLA_SHA256 = "" # ) local_repository( - name = "xlae", + name = "xla", path = "./xla" ) diff --git a/src/enzyme_ad/jax/BUILD b/src/enzyme_ad/jax/BUILD index f5b4d0a4b..8b6909540 100644 --- a/src/enzyme_ad/jax/BUILD +++ b/src/enzyme_ad/jax/BUILD @@ -1,4 +1,5 @@ load("@pybind11_bazel//:build_defs.bzl", "pybind_extension", "pybind_library") +load("@llvm-project//llvm:tblgen.bzl", "gentbl") licenses(["notice"]) @@ -39,11 +40,75 @@ py_library( visibility = ["//visibility:public"] ) + +gentbl( + name = "mhlo-derivatives", + tbl_outs = [( + "-gen-mlir-derivatives", + "Implementations/MHLODerivatives.inc", + )], + tblgen = "@enzyme//:enzyme-tblgen", + td_file = "Implementations/MHLODerivatives.td", + td_srcs = [ + "Implementations/MHLODerivatives.td", + "Implementations/Common.td", + ], + deps = [ + "@enzyme//:enzyme-tblgen", + ], +) + +td_library( + name = "EnzymeXLAPassesTdFiles", + srcs = [ + ], + deps = [ + "@llvm-project//mlir:PassBaseTdFiles", + ], +) + +gentbl_cc_library( + name = "EnzymeXLAPassesIncGen", + tbl_outs = [ + ( + [ + "-gen-pass-decls", + "-name=enzymexla", + ], + "Passes/Passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "Passes/Passes.td", + deps = [":EnzymeXLAPassesTdFiles"], +) + +cc_library( + name = "XLADerivatives", + srcs = glob( + [ + "Implementations/*.cpp", + "Passes/*.cpp", + ], + ), + hdrs = glob([ + "Implementations/*.h", + "Passes/*.h", + ]), + deps = [ + ":EnzymeXLAPassesIncGen", + ":mhlo-derivatives", + "@xla//xla/mlir_hlo", + "@enzyme//:EnzymeMLIR", + ] +) + pybind_library( name = "compile_with_xla", srcs = ["compile_with_xla.cc"], - hdrs = ["compile_with_xla.h"], + hdrs = glob(["compile_with_xla.h", "Implementations/*.h", "Passes/*.h"]), deps = [ + ":XLADerivatives", # This is similar to xla_binary rule and is needed to make XLA client compile. "@tsl//tsl/framework:allocator", "@tsl//tsl/framework:allocator_registry_impl", diff --git a/src/enzyme_ad/jax/Implementations/Common.td b/src/enzyme_ad/jax/Implementations/Common.td new file mode 100644 index 000000000..55197f10c --- /dev/null +++ b/src/enzyme_ad/jax/Implementations/Common.td @@ -0,0 +1,69 @@ +class InactiveOp { + string dialect = dialect_; + string opName = opName_; +} + +class AllocationOp { + string dialect = dialect_; + string opName = opName_; +} + +class ControlFlowOp { + string dialect = dialect_; + string opName = opName_; + string impl = impl_; +} + +class MemoryIdentityOp ptrargs_, list storedargs_ = []> { + string dialect = dialect_; + string opName = opName_; + list ptrargs = ptrargs_; + list storedargs = storedargs_; +} + +class ReadOnlyIdentityOp ptrargs_> : MemoryIdentityOp; + +class BranchOp { + string dialect = dialect_; + string opName = opName_; +} + +class RegionTerminatorOp { + string dialect = dialect_; + string opName = opName_; +} + +class MLIRDerivative resultOps> { + string dialect = dialect_; + string opName = opName_; + dag PatternToMatch = patternToMatch; + list ArgDerivatives = resultOps; +} + +class Operation { + bit usesPrimal = usesPrimal_; + bit usesShadow = usesShadow_; + bit usesCustom = usesCustom_; +} + +class DiffeRetIndex indices_> { + list indices = indices_; +} +def DiffeRet : DiffeRetIndex<[-1]>; + +class Inst : Operation { + string name = mnemonic; + string dialect = dialect_; +} + +def Op { +} + +class MHLOInst : Inst; + +def Add : MHLOInst<"mhlo::AddOp">; +def Sub : MHLOInst<"mhlo::SubOp">; +def Neg : MHLOInst<"mhlo::NegOp">; +def Mul : MHLOInst<"mhlo::MulOp">; +def Div : MHLOInst<"mhlo::DivOp">; +def Rem : MHLOInst<"mhlo::RemOp">; diff --git a/src/enzyme_ad/jax/Implementations/MHLOAutoDiffOpInterfaceImpl.cpp b/src/enzyme_ad/jax/Implementations/MHLOAutoDiffOpInterfaceImpl.cpp new file mode 100644 index 000000000..1645a8e3c --- /dev/null +++ b/src/enzyme_ad/jax/Implementations/MHLOAutoDiffOpInterfaceImpl.cpp @@ -0,0 +1,74 @@ +//===- ArithAutoDiffOpInterfaceImpl.cpp - Interface external model --------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains the external model implementation of the automatic +// differentiation op interfaces for the upstream MLIR arithmetic dialect. +// +//===----------------------------------------------------------------------===// + +#include "Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h" +#include "Enzyme/MLIR/Interfaces/AutoDiffOpInterface.h" +#include "Enzyme/MLIR/Interfaces/GradientUtils.h" +#include "Enzyme/MLIR/Interfaces/GradientUtilsReverse.h" +#include "mlir/IR/DialectRegistry.h" +#include "mlir/Support/LogicalResult.h" + +#include "Dialect/Ops.h" +#include "mlir/IR/TypeSupport.h" + +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" + +#include "src/enzyme_ad/jax/Implementations/XLADerivatives.h" + +using namespace mlir; +using namespace mlir::enzyme; + +namespace { +#include "src/enzyme_ad/jax/Implementations/MHLODerivatives.inc" + +#ifndef DISABLE_MHLO_TENSOR_INTERFACE +class MHLOTensorTypeInterface + : public AutoDiffTypeInterface::ExternalModel { +public: + Value createNullValue(Type self, OpBuilder &builder, Location loc) const { + auto tenType = self.cast(); + auto attr = DenseElementsAttr::get(tenType, 0); + return builder.create(loc, tenType, attr); + } + + Value createAddOp(Type self, OpBuilder &builder, Location loc, Value a, + Value b) const { + return builder.create(loc, a, b); + } + + Type getShadowType(Type self, unsigned width) const { + assert(width == 1 && "unsupported width != 1"); + return self; + } + + bool requiresShadow(Type self) const { return false; } + LogicalResult zeroInPlace(Type self, OpBuilder &builder, Location loc, + Value val) const { + return failure(); + } +}; +#endif +} // namespace + +void mlir::enzyme::registerMHLODialectAutoDiffInterface( + DialectRegistry ®istry) { + registry.addExtension(+[](MLIRContext *context, mhlo::MhloDialect *) { + registerInterfaces(context); + +#ifndef DISABLE_MHLO_TENSOR_INTERFACE + UnrankedTensorType::attachInterface(*context); + RankedTensorType::attachInterface(*context); +#endif + }); +} diff --git a/src/enzyme_ad/jax/Implementations/MHLODerivatives.td b/src/enzyme_ad/jax/Implementations/MHLODerivatives.td new file mode 100644 index 000000000..ca4d46aab --- /dev/null +++ b/src/enzyme_ad/jax/Implementations/MHLODerivatives.td @@ -0,0 +1,8 @@ +include "Common.td" + +def : MLIRDerivative<"mhlo", "AddOp", (Op $x, $y), + [ + (DiffeRet), + (DiffeRet), + ] + >; diff --git a/src/enzyme_ad/jax/Implementations/XLADerivatives.h b/src/enzyme_ad/jax/Implementations/XLADerivatives.h new file mode 100644 index 000000000..6741533b5 --- /dev/null +++ b/src/enzyme_ad/jax/Implementations/XLADerivatives.h @@ -0,0 +1,11 @@ +#include "mlir/IR/DialectRegistry.h" + +namespace mlir { +namespace enzyme { +void registerMHLODialectAutoDiffInterface(mlir::DialectRegistry ®istry); + +static inline void registerXLAAutoDiffInterfaces(mlir::DialectRegistry ®istry) { + registerMHLODialectAutoDiffInterface(registry); +} +} +} diff --git a/src/enzyme_ad/jax/compile_with_xla.cc b/src/enzyme_ad/jax/compile_with_xla.cc index 0052863db..d6a17ccd5 100644 --- a/src/enzyme_ad/jax/compile_with_xla.cc +++ b/src/enzyme_ad/jax/compile_with_xla.cc @@ -31,7 +31,12 @@ #include "compile_with_xla.h" #include "Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h" +#include "Implementations/XLADerivatives.h" +void prepareRegistry(mlir::DialectRegistry ®istry) { + mlir::enzyme::registerCoreDialectAutodiffInterfaces(registry); + mlir::enzyme::registerXLAAutoDiffInterfaces(registry); +} // Compile an MHLO module given as a string to LLVM IR using XLA. std::unique_ptr compile_mhlo_to_llvm_with_xla(llvm::StringRef mhlo_text, std::string &output, @@ -39,7 +44,7 @@ compile_mhlo_to_llvm_with_xla(llvm::StringRef mhlo_text, std::string &output, const std::string &pass_pipeline) { // Parse MLIR. mlir::DialectRegistry registry; - mlir::enzyme::registerCoreDialectAutodiffInterfaces(registry); + prepareRegistry(registry); mlir::MLIRContext context(registry); context.loadDialect(); context.loadDialect(); @@ -132,11 +137,13 @@ compile_mhlo_to_llvm_with_xla(llvm::StringRef mhlo_text, std::string &output, throw pybind11::value_error(executor.status().ToString()); } + xla::Compiler::CompileOptions opts = {build_options.device_allocator(), build_options.compile_thread_pool(), + build_options.layout_canonicalization_callback()}; + opts.registry = ®istry; auto executable = local_client->local_service()->BuildExecutable( xla_computation.proto(), std::move(module_config_or_error.value()), local_client->mutable_backend(), executor.value(), - {build_options.device_allocator(), build_options.compile_thread_pool(), - build_options.layout_canonicalization_callback(), ®istry}, + opts, build_options.run_backend_only()); if (!executable.ok()) { throw pybind11::value_error(executable.status().ToString()); diff --git a/src/enzyme_ad/jax/primitives.py b/src/enzyme_ad/jax/primitives.py index 66f1ab754..165a9ddd5 100644 --- a/src/enzyme_ad/jax/primitives.py +++ b/src/enzyme_ad/jax/primitives.py @@ -796,8 +796,8 @@ def make_zero(tan, prim): shadconv = None if pipeline_options.mlir_ad(): - act_tup = (",".join(["enzyme_dup" for a in args])) - newpasses = "enzyme-wrap{infn=main outfn=main retTy=enzyme_dup argTys="+act_tup+" mode=ForwardMode}," + pipeline_options.pass_pipeline() + act_tup = (",".join(["enzyme_dup" for a in arg_primals])) + newpasses = "enzyme-wrap{infn=main outfn= retTy=enzyme_dup argTys="+act_tup+" mode=ForwardMode}," + pipeline_options.pass_pipeline() pipeline_options = NewXLAPipeline(newpasses, pipeline_options.mlir_ad()) outshapes2 = [] for o in kwargs["out_shapes"]: