Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Feb 15, 2024
1 parent 9fbdc65 commit 38e23d5
Show file tree
Hide file tree
Showing 18 changed files with 189 additions and 67 deletions.
2 changes: 1 addition & 1 deletion WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ load("@rules_python//python/pip_install:repositories.bzl", "pip_install_dependen

pip_install_dependencies()

ENZYME_COMMIT = "b74b7e9e8f67198c2b96724a1d042714ff6b2277"
ENZYME_COMMIT = "1b0a46dd751f204e1bbef8cc0641e3a1ae27c74f"
ENZYME_SHA256 = ""

http_archive(
Expand Down
21 changes: 21 additions & 0 deletions src/enzyme_ad/jax/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,25 @@ gentbl(
td_srcs = [
"Implementations/MHLODerivatives.td",
"Implementations/Common.td",
"Implementations/HLODerivatives.td",
],
deps = [
"@enzyme//:enzyme-tblgen",
],
)

gentbl(
name = "stablehlo-derivatives",
tbl_outs = [(
"-gen-mlir-derivatives",
"Implementations/StableHLODerivatives.inc",
)],
tblgen = "@enzyme//:enzyme-tblgen",
td_file = "Implementations/StableHLODerivatives.td",
td_srcs = [
"Implementations/StableHLODerivatives.td",
"Implementations/Common.td",
"Implementations/HLODerivatives.td",
],
deps = [
"@enzyme//:enzyme-tblgen",
Expand Down Expand Up @@ -99,6 +118,8 @@ cc_library(
deps = [
":EnzymeXLAPassesIncGen",
":mhlo-derivatives",
":stablehlo-derivatives",
"@stablehlo//:stablehlo_ops",
"@xla//xla/mlir_hlo",
"@enzyme//:EnzymeMLIR",
]
Expand Down
9 changes: 0 additions & 9 deletions src/enzyme_ad/jax/Implementations/Common.td
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,3 @@ class Inst<string mnemonic, string dialect_> : Operation</*primal*/1, /*shadow*

def Op {
}

class MHLOInst<string m> : Inst<m, "mhlo">;

def Add : MHLOInst<"mhlo::AddOp">;
def Sub : MHLOInst<"mhlo::SubOp">;
def Neg : MHLOInst<"mhlo::NegOp">;
def Mul : MHLOInst<"mhlo::MulOp">;
def Div : MHLOInst<"mhlo::DivOp">;
def Rem : MHLOInst<"mhlo::RemOp">;
8 changes: 8 additions & 0 deletions src/enzyme_ad/jax/Implementations/HLODerivatives.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@


def : HLODerivative<"AddOp", (Op $x, $y),
[
(DiffeRet),
(DiffeRet),
]
>;
34 changes: 0 additions & 34 deletions src/enzyme_ad/jax/Implementations/MHLOAutoDiffOpInterfaceImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,45 +30,11 @@ using namespace mlir::enzyme;

namespace {
#include "src/enzyme_ad/jax/Implementations/MHLODerivatives.inc"

#ifndef DISABLE_MHLO_TENSOR_INTERFACE
class MHLOTensorTypeInterface
: public AutoDiffTypeInterface::ExternalModel<MHLOTensorTypeInterface,
TensorType> {
public:
Value createNullValue(Type self, OpBuilder &builder, Location loc) const {
auto tenType = self.cast<TensorType>();
auto attr = DenseElementsAttr::get(tenType, 0);
return builder.create<mhlo::ConstantOp>(loc, tenType, attr);
}

Value createAddOp(Type self, OpBuilder &builder, Location loc, Value a,
Value b) const {
return builder.create<mhlo::AddOp>(loc, a, b);
}

Type getShadowType(Type self, unsigned width) const {
assert(width == 1 && "unsupported width != 1");
return self;
}

bool requiresShadow(Type self) const { return false; }
LogicalResult zeroInPlace(Type self, OpBuilder &builder, Location loc,
Value val) const {
return failure();
}
};
#endif
} // namespace

void mlir::enzyme::registerMHLODialectAutoDiffInterface(
DialectRegistry &registry) {
registry.addExtension(+[](MLIRContext *context, mhlo::MhloDialect *) {
registerInterfaces(context);

#ifndef DISABLE_MHLO_TENSOR_INTERFACE
UnrankedTensorType::attachInterface<MHLOTensorTypeInterface>(*context);
RankedTensorType::attachInterface<MHLOTensorTypeInterface>(*context);
#endif
});
}
18 changes: 12 additions & 6 deletions src/enzyme_ad/jax/Implementations/MHLODerivatives.td
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
include "Common.td"

def : MLIRDerivative<"mhlo", "AddOp", (Op $x, $y),
[
(DiffeRet),
(DiffeRet),
]
>;
class HLODerivative<string opName_, dag patternToMatch, list<dag> resultOps> : MLIRDerivative<"mhlo", opName_, patternToMatch, resultOps>;

class HLOInst<string m> : Inst<m, "mhlo">;

def Add : HLOInst<"AddOp">;
def Sub : HLOInst<"mhlo::SubOp">;
def Neg : HLOInst<"mhlo::NegOp">;
def Mul : HLOInst<"mhlo::MulOp">;
def Div : HLOInst<"mhlo::DivOp">;
def Rem : HLOInst<"mhlo::RemOp">;

include "HLODerivatives.td"
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
//===- ArithAutoDiffOpInterfaceImpl.cpp - Interface external model --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains the external model implementation of the automatic
// differentiation op interfaces for the upstream MLIR arithmetic dialect.
//
//===----------------------------------------------------------------------===//

#include "Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h"
#include "Enzyme/MLIR/Interfaces/AutoDiffOpInterface.h"
#include "Enzyme/MLIR/Interfaces/GradientUtils.h"
#include "Enzyme/MLIR/Interfaces/GradientUtilsReverse.h"
#include "mlir/IR/DialectRegistry.h"
#include "mlir/Support/LogicalResult.h"

#include "Dialect/Ops.h"
#include "mlir/IR/TypeSupport.h"

#include "stablehlo/dialect/StablehloOps.h"

#include "src/enzyme_ad/jax/Implementations/XLADerivatives.h"

using namespace mlir;
using namespace mlir::enzyme;

namespace {
#include "src/enzyme_ad/jax/Implementations/StableHLODerivatives.inc"
} // namespace

void mlir::enzyme::registerStableHLODialectAutoDiffInterface(
DialectRegistry &registry) {
registry.addExtension(+[](MLIRContext *context, stablehlo::StablehloDialect *) {
registerInterfaces(context);
});
}
14 changes: 14 additions & 0 deletions src/enzyme_ad/jax/Implementations/StableHLODerivatives.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
include "Common.td"

class HLODerivative<string opName_, dag patternToMatch, list<dag> resultOps> : MLIRDerivative<"stablehlo", opName_, patternToMatch, resultOps>;

class HLOInst<string m> : Inst<m, "mhlo">;

def Add : HLOInst<"AddOp">;
def Sub : HLOInst<"mhlo::SubOp">;
def Neg : HLOInst<"mhlo::NegOp">;
def Mul : HLOInst<"mhlo::MulOp">;
def Div : HLOInst<"mhlo::DivOp">;
def Rem : HLOInst<"mhlo::RemOp">;

include "HLODerivatives.td"
2 changes: 2 additions & 0 deletions src/enzyme_ad/jax/Implementations/XLADerivatives.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
namespace mlir {
namespace enzyme {
void registerMHLODialectAutoDiffInterface(mlir::DialectRegistry &registry);
void registerStableHLODialectAutoDiffInterface(mlir::DialectRegistry &registry);

static inline void
registerXLAAutoDiffInterfaces(mlir::DialectRegistry &registry) {
registerMHLODialectAutoDiffInterface(registry);
registerStableHLODialectAutoDiffInterface(registry);
}
} // namespace enzyme
} // namespace mlir
20 changes: 16 additions & 4 deletions src/enzyme_ad/jax/Passes/ArithRaising.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
#include "src/enzyme_ad/jax/Passes/Passes.h"
#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"

#include "stablehlo/dialect/StablehloOps.h"

#include "mlir/Dialect/Func/IR/FuncOps.h"

#define DEBUG_TYPE "enzyme"
Expand All @@ -33,16 +35,26 @@ struct ArithRaisingPass : public ArithRaisingPassBase<ArithRaisingPass> {
void runOnOperation() override {
auto op = getOperation();

op->walk([](arith::AddFOp addOp) {
op->walk([=](arith::AddFOp addOp) {
OpBuilder builder(addOp);
Value newAddOp = builder.create<mhlo::AddOp>(
Value newAddOp;
if (use_stablehlo)
newAddOp = builder.create<stablehlo::AddOp>(
addOp.getLoc(), addOp->getOperand(0), addOp->getOperand(1));
else
newAddOp = builder.create<mhlo::AddOp>(
addOp.getLoc(), addOp->getOperand(0), addOp->getOperand(1));
addOp.replaceAllUsesWith(newAddOp);
addOp.erase();
});
op->walk([](arith::AddIOp addOp) {
op->walk([=](arith::AddIOp addOp) {
OpBuilder builder(addOp);
Value newAddOp = builder.create<mhlo::AddOp>(
Value newAddOp;
if (use_stablehlo)
newAddOp = builder.create<stablehlo::AddOp>(
addOp.getLoc(), addOp->getOperand(0), addOp->getOperand(1));
else
newAddOp = builder.create<mhlo::AddOp>(
addOp.getLoc(), addOp->getOperand(0), addOp->getOperand(1));
addOp.replaceAllUsesWith(newAddOp);
addOp.erase();
Expand Down
4 changes: 4 additions & 0 deletions src/enzyme_ad/jax/Passes/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ namespace mhlo {
class MhloDialect;
} // end namespace mhlo

namespace stablehlo {
class StablehloDialect;
} // end namespace arith

namespace arith {
class ArithDialect;
} // end namespace arith
Expand Down
12 changes: 11 additions & 1 deletion src/enzyme_ad/jax/Passes/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,19 @@ def ArithRaisingPass : Pass<"arith-raise"> {
let summary = "Raise Arith to mhlo";
let dependentDialects = [
"arith::ArithDialect",
"mhlo::MhloDialect"
"mhlo::MhloDialect",
"stablehlo::StablehloDialect"
];
let constructor = "mlir::enzyme::createArithRaisingPass()";
let options = [
Option<
/*C++ variable name=*/"use_stablehlo",
/*CLI argument=*/"stablehlo",
/*type=*/"bool",
/*default=*/"true",
/*description=*/"Whether to raise to stablehlo vs mhlo"
>
];
}

def PrintPass : Pass<"print"> {
Expand Down
39 changes: 39 additions & 0 deletions src/enzyme_ad/jax/Passes/PrintPass.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
//===- PrintPass.cpp - Print the MLIR module ------------ //
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to print the MLIR module
//===----------------------------------------------------------------------===//

#include "src/enzyme_ad/jax/Passes/Passes.h"
#include "src/enzyme_ad/jax/Passes/PassDetails.h"

#define DEBUG_TYPE "enzyme"

using namespace mlir;
using namespace mlir::enzyme;
using namespace enzyme;

namespace {
struct PrintPass
: public PrintPassBase<PrintPass> {

void runOnOperation() override {
llvm::errs() << *getOperation() << "\n";
}
};

} // end anonymous namespace

namespace mlir {
namespace enzyme {
std::unique_ptr<Pass> createPrintPass() {
return std::make_unique<PrintPass>();
}
} // namespace enzyme
} // namespace mlir

4 changes: 0 additions & 4 deletions src/enzyme_ad/jax/clang_compile.cc
Original file line number Diff line number Diff line change
Expand Up @@ -532,8 +532,6 @@ struct tensor<T, n0, N...>
f.setLinkage(Function::LinkageTypes::InternalLinkage);
}

llvm::errs() << " postlinkMod: " << *mod << "\n";

PipelineTuningOptions PTO;
LoopAnalysisManager LAM;
FunctionAnalysisManager FAM;
Expand Down Expand Up @@ -655,7 +653,5 @@ struct tensor<T, n0, N...>
}
}

llvm::errs() << " clangcompmod: " << *mod << "\n";

return mod;
}
2 changes: 2 additions & 0 deletions src/enzyme_ad/jax/compile_with_xla.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,13 @@ compile_mhlo_to_llvm_with_xla(llvm::StringRef mhlo_text, std::string &output,
mlir::OwningOpRef<mlir::ModuleOp> parsed_module =
mlir::parseSourceString<mlir::ModuleOp>(mhlo_text, parser_config);

if (!xla_runtime) {
mlir::PassManager pm(&context);
pm.addPass(mlir::mhlo::createStablehloLegalizeToHloPass());
if (!mlir::succeeded(pm.run(*parsed_module))) {
throw pybind11::value_error("StableHLO => MHLO failed");
}
}

// Convert to XLA Computation.
xla::HloProto hlo_proto;
Expand Down
7 changes: 4 additions & 3 deletions src/enzyme_ad/jax/enzyme_call.cc
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ class CpuKernel {
auto *cpu_executable = static_cast<xla::cpu::CpuExecutable *>(
local_executable->executable());
auto &assignment = cpu_executable->buffer_assignment();
if (!xla_runtime) {
size_t num_in = 0;
for (auto &buf2 : assignment.Allocations()) {
if (buf2.is_entry_computation_parameter()) {
Expand All @@ -133,6 +134,8 @@ class CpuKernel {
if (num_in != in_shapes.size()) {
std::string err_str;
llvm::raw_string_ostream ss(err_str);
ss << assignment.ToString() << "\n";
ss << source << "\n";
ss << " Number of mhlo inputs (" << num_in
<< ") != number of jax inputs (" << in_shapes.size() << "):\n";
ss << source << "\n";
Expand All @@ -157,6 +160,7 @@ class CpuKernel {
throw pybind11::value_error(ss.str());
}
}
}
source = stringbuf;
if (xla_runtime)
tmpBuf = 0;
Expand Down Expand Up @@ -207,7 +211,6 @@ class CpuKernel {
nullptr);
}
}
llvm::errs() << "linkMod: " << *linkMod << "\n";
}
if (xla_runtime) {
ss << " extern \"C\" void " << fn << "(void* exec";
Expand Down Expand Up @@ -835,12 +838,10 @@ class CpuKernel {
#endif
}

llvm::errs() << " str: " << ss.str() << "\n";
auto mod = GetLLVMFromJob("/enzyme_call/source.cpp", ss.str(), /*cpp*/ true,
pyargv_strs, llvm_ctx.get(), std::move(linkMod));
if (!mod)
throw pybind11::value_error("failed to compile C++");
llvm::errs() << " postmod: " << *mod << "\n";
return std::make_tuple(std::move(mod), std::move(llvm_ctx), out_off,
tmpBuf);
}
Expand Down
Loading

0 comments on commit 38e23d5

Please sign in to comment.