Skip to content

Commit

Permalink
continuing
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Feb 14, 2024
1 parent 8782988 commit 95d45b6
Show file tree
Hide file tree
Showing 9 changed files with 241 additions and 7 deletions.
2 changes: 1 addition & 1 deletion WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ XLA_SHA256 = ""
# )

local_repository(
name = "xlae",
name = "xla",
path = "./xla"
)

Expand Down
Binary file removed src/enzyme_ad/jax/.primitives.py.swp
Binary file not shown.
67 changes: 66 additions & 1 deletion src/enzyme_ad/jax/BUILD
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
load("@pybind11_bazel//:build_defs.bzl", "pybind_extension", "pybind_library")
load("@llvm-project//llvm:tblgen.bzl", "gentbl")

licenses(["notice"])

Expand Down Expand Up @@ -39,11 +40,75 @@ py_library(
visibility = ["//visibility:public"]
)


gentbl(
name = "mhlo-derivatives",
tbl_outs = [(
"-gen-mlir-derivatives",
"Implementations/MHLODerivatives.inc",
)],
tblgen = "@enzyme//:enzyme-tblgen",
td_file = "Implementations/MHLODerivatives.td",
td_srcs = [
"Implementations/MHLODerivatives.td",
"Implementations/Common.td",
],
deps = [
"@enzyme//:enzyme-tblgen",
],
)

td_library(
name = "EnzymeXLAPassesTdFiles",
srcs = [
],
deps = [
"@llvm-project//mlir:PassBaseTdFiles",
],
)

gentbl_cc_library(
name = "EnzymeXLAPassesIncGen",
tbl_outs = [
(
[
"-gen-pass-decls",
"-name=enzymexla",
],
"Passes/Passes.h.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "Passes/Passes.td",
deps = [":EnzymeXLAPassesTdFiles"],
)

cc_library(
name = "XLADerivatives",
srcs = glob(
[
"Implementations/*.cpp",
"Passes/*.cpp",
],
),
hdrs = glob([
"Implementations/*.h",
"Passes/*.h",
]),
deps = [
":EnzymeXLAPassesIncGen",
":mhlo-derivatives",
"@xla//xla/mlir_hlo",
"@enzyme//:EnzymeMLIR",
]
)

pybind_library(
name = "compile_with_xla",
srcs = ["compile_with_xla.cc"],
hdrs = ["compile_with_xla.h"],
hdrs = glob(["compile_with_xla.h", "Implementations/*.h", "Passes/*.h"]),
deps = [
":XLADerivatives",
# This is similar to xla_binary rule and is needed to make XLA client compile.
"@tsl//tsl/framework:allocator",
"@tsl//tsl/framework:allocator_registry_impl",
Expand Down
69 changes: 69 additions & 0 deletions src/enzyme_ad/jax/Implementations/Common.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
class InactiveOp<string dialect_, string opName_> {
string dialect = dialect_;
string opName = opName_;
}

class AllocationOp<string dialect_, string opName_> {
string dialect = dialect_;
string opName = opName_;
}

class ControlFlowOp<string dialect_, string opName_, string impl_> {
string dialect = dialect_;
string opName = opName_;
string impl = impl_;
}

class MemoryIdentityOp<string dialect_, string opName_, list<int> ptrargs_, list<int> storedargs_ = []> {
string dialect = dialect_;
string opName = opName_;
list<int> ptrargs = ptrargs_;
list<int> storedargs = storedargs_;
}

class ReadOnlyIdentityOp<string dialect_, string opName_, list<int> ptrargs_> : MemoryIdentityOp<dialect_, opName_, ptrargs_>;

class BranchOp<string dialect_, string opName_> {
string dialect = dialect_;
string opName = opName_;
}

class RegionTerminatorOp<string dialect_, string opName_> {
string dialect = dialect_;
string opName = opName_;
}

class MLIRDerivative<string dialect_, string opName_, dag patternToMatch, list<dag> resultOps> {
string dialect = dialect_;
string opName = opName_;
dag PatternToMatch = patternToMatch;
list<dag> ArgDerivatives = resultOps;
}

class Operation<bit usesPrimal_, bit usesShadow_, bit usesCustom_=0> {
bit usesPrimal = usesPrimal_;
bit usesShadow = usesShadow_;
bit usesCustom = usesCustom_;
}

class DiffeRetIndex<list<int> indices_> {
list<int> indices = indices_;
}
def DiffeRet : DiffeRetIndex<[-1]>;

class Inst<string mnemonic, string dialect_> : Operation</*primal*/1, /*shadow*/0> {
string name = mnemonic;
string dialect = dialect_;
}

def Op {
}

class MHLOInst<string m> : Inst<m, "mhlo">;

def Add : MHLOInst<"mhlo::AddOp">;
def Sub : MHLOInst<"mhlo::SubOp">;
def Neg : MHLOInst<"mhlo::NegOp">;
def Mul : MHLOInst<"mhlo::MulOp">;
def Div : MHLOInst<"mhlo::DivOp">;
def Rem : MHLOInst<"mhlo::RemOp">;
74 changes: 74 additions & 0 deletions src/enzyme_ad/jax/Implementations/MHLOAutoDiffOpInterfaceImpl.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
//===- ArithAutoDiffOpInterfaceImpl.cpp - Interface external model --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains the external model implementation of the automatic
// differentiation op interfaces for the upstream MLIR arithmetic dialect.
//
//===----------------------------------------------------------------------===//

#include "Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h"
#include "Enzyme/MLIR/Interfaces/AutoDiffOpInterface.h"
#include "Enzyme/MLIR/Interfaces/GradientUtils.h"
#include "Enzyme/MLIR/Interfaces/GradientUtilsReverse.h"
#include "mlir/IR/DialectRegistry.h"
#include "mlir/Support/LogicalResult.h"

#include "Dialect/Ops.h"
#include "mlir/IR/TypeSupport.h"

#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"

#include "src/enzyme_ad/jax/Implementations/XLADerivatives.h"

using namespace mlir;
using namespace mlir::enzyme;

namespace {
#include "src/enzyme_ad/jax/Implementations/MHLODerivatives.inc"

#ifndef DISABLE_MHLO_TENSOR_INTERFACE
class MHLOTensorTypeInterface
: public AutoDiffTypeInterface::ExternalModel<MHLOTensorTypeInterface,
TensorType> {
public:
Value createNullValue(Type self, OpBuilder &builder, Location loc) const {
auto tenType = self.cast<TensorType>();
auto attr = DenseElementsAttr::get(tenType, 0);
return builder.create<mhlo::ConstantOp>(loc, tenType, attr);
}

Value createAddOp(Type self, OpBuilder &builder, Location loc, Value a,
Value b) const {
return builder.create<mhlo::AddOp>(loc, a, b);
}

Type getShadowType(Type self, unsigned width) const {
assert(width == 1 && "unsupported width != 1");
return self;
}

bool requiresShadow(Type self) const { return false; }
LogicalResult zeroInPlace(Type self, OpBuilder &builder, Location loc,
Value val) const {
return failure();
}
};
#endif
} // namespace

void mlir::enzyme::registerMHLODialectAutoDiffInterface(
DialectRegistry &registry) {
registry.addExtension(+[](MLIRContext *context, mhlo::MhloDialect *) {
registerInterfaces(context);

#ifndef DISABLE_MHLO_TENSOR_INTERFACE
UnrankedTensorType::attachInterface<MHLOTensorTypeInterface>(*context);
RankedTensorType::attachInterface<MHLOTensorTypeInterface>(*context);
#endif
});
}
8 changes: 8 additions & 0 deletions src/enzyme_ad/jax/Implementations/MHLODerivatives.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
include "Common.td"

def : MLIRDerivative<"mhlo", "AddOp", (Op $x, $y),
[
(DiffeRet),
(DiffeRet),
]
>;
11 changes: 11 additions & 0 deletions src/enzyme_ad/jax/Implementations/XLADerivatives.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
#include "mlir/IR/DialectRegistry.h"

namespace mlir {
namespace enzyme {
void registerMHLODialectAutoDiffInterface(mlir::DialectRegistry &registry);

static inline void registerXLAAutoDiffInterfaces(mlir::DialectRegistry &registry) {
registerMHLODialectAutoDiffInterface(registry);
}
}
}
13 changes: 10 additions & 3 deletions src/enzyme_ad/jax/compile_with_xla.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,20 @@
#include "compile_with_xla.h"

#include "Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h"
#include "Implementations/XLADerivatives.h"

void prepareRegistry(mlir::DialectRegistry &registry) {
mlir::enzyme::registerCoreDialectAutodiffInterfaces(registry);
mlir::enzyme::registerXLAAutoDiffInterfaces(registry);
}
// Compile an MHLO module given as a string to LLVM IR using XLA.
std::unique_ptr<xla::LocalExecutable>
compile_mhlo_to_llvm_with_xla(llvm::StringRef mhlo_text, std::string &output,
bool xla_runtime,
const std::string &pass_pipeline) {
// Parse MLIR.
mlir::DialectRegistry registry;
mlir::enzyme::registerCoreDialectAutodiffInterfaces(registry);
prepareRegistry(registry);
mlir::MLIRContext context(registry);
context.loadDialect<mlir::arith::ArithDialect>();
context.loadDialect<mlir::func::FuncDialect>();
Expand Down Expand Up @@ -132,11 +137,13 @@ compile_mhlo_to_llvm_with_xla(llvm::StringRef mhlo_text, std::string &output,
throw pybind11::value_error(executor.status().ToString());
}

xla::Compiler::CompileOptions opts = {build_options.device_allocator(), build_options.compile_thread_pool(),
build_options.layout_canonicalization_callback()};
opts.registry = &registry;
auto executable = local_client->local_service()->BuildExecutable(
xla_computation.proto(), std::move(module_config_or_error.value()),
local_client->mutable_backend(), executor.value(),
{build_options.device_allocator(), build_options.compile_thread_pool(),
build_options.layout_canonicalization_callback(), &registry},
opts,
build_options.run_backend_only());
if (!executable.ok()) {
throw pybind11::value_error(executable.status().ToString());
Expand Down
4 changes: 2 additions & 2 deletions src/enzyme_ad/jax/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -796,8 +796,8 @@ def make_zero(tan, prim):

shadconv = None
if pipeline_options.mlir_ad():
act_tup = (",".join(["enzyme_dup" for a in args]))
newpasses = "enzyme-wrap{infn=main outfn=main retTy=enzyme_dup argTys="+act_tup+" mode=ForwardMode}," + pipeline_options.pass_pipeline()
act_tup = (",".join(["enzyme_dup" for a in arg_primals]))
newpasses = "enzyme-wrap{infn=main outfn= retTy=enzyme_dup argTys="+act_tup+" mode=ForwardMode}," + pipeline_options.pass_pipeline()
pipeline_options = NewXLAPipeline(newpasses, pipeline_options.mlir_ad())
outshapes2 = []
for o in kwargs["out_shapes"]:
Expand Down

0 comments on commit 95d45b6

Please sign in to comment.