Skip to content

Commit

Permalink
bump commits
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Feb 14, 2024
1 parent b96dcf4 commit fd1f337
Show file tree
Hide file tree
Showing 5 changed files with 236 additions and 2 deletions.
4 changes: 2 additions & 2 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ http_archive(
load("@llvm-raw//utils/bazel:configure.bzl", "llvm_configure")
llvm_configure(name = "llvm-project", targets = LLVM_TARGETS)

XLA_COMMIT = "abe31d7bb7d49807be0acc5da33647ffa759741b"
XLA_COMMIT = "6ee7005b0dbe29ba0cd077a690db1555ec6de346"
XLA_SHA256 = ""

http_archive(
Expand Down Expand Up @@ -61,7 +61,7 @@ load("@rules_python//python/pip_install:repositories.bzl", "pip_install_dependen
pip_install_dependencies()

ENZYME_COMMIT = "9384fe20caec02bd30f302e32f4f1c1f7ccb7d9d"
ENZYME_SHA256 = ""
ENZYME_SHA256 = "a0846fab3a2927d84ab5c9a51e539e93a9a7d6c6ec5d87722f24074fefd0c3f8"

http_archive(
name = "enzyme",
Expand Down
106 changes: 106 additions & 0 deletions src/enzyme_ad/jax/Passes/ArithRaising.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
//===- EnzymeWrapPass.cpp - Replace calls with their derivatives ------------ //
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to create wrapper functions which differentiate
// ops.
//===----------------------------------------------------------------------===//

#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "src/enzyme_ad/jax/Passes/Passes.h"
#include "src/enzyme_ad/jax/Passes/PassDetails.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/FunctionInterfaces.h"

#include "mlir/Dialect/Func/IR/FuncOps.h"

#define DEBUG_TYPE "enzyme"

using namespace mlir;
using namespace mlir::enzyme;
using namespace enzyme;

namespace {
struct ArithRaisingPass
: public ArithRaisingPassBase<ArithRaisingPass> {

void runOnOperation() override {

Operation *symbolOp = nullptr;
if (infn != "")
symbolOp = symbolTable.lookupSymbolIn<Operation *>(
getOperation(), StringAttr::get(getOperation()->getContext(), infn));
else {
for (auto &op : getOperation()->getRegion(0).front()) {
auto fn = dyn_cast<FunctionOpInterface>(symbolOp);
if (!fn)
continue;
assert(symbolOp == nullptr);
symbolOp = &op;
}
}
auto fn = cast<FunctionOpInterface>(symbolOp);
SmallVector<StringRef, 1> split;
StringRef(argTys.getValue().data(), argTys.getValue().size())
.split(split, ',');
std::vector<DIFFE_TYPE> constants;
for (auto &str : split) {
if (str == "enzyme_dup")
constants.push_back(DIFFE_TYPE::DUP_ARG);
else if (str == "enzyme_const")
constants.push_back(DIFFE_TYPE::CONSTANT);
else if (str == "enzyme_dupnoneed")
constants.push_back(DIFFE_TYPE::DUP_NONEED);
else if (str == "enzyme_out")
constants.push_back(DIFFE_TYPE::OUT_DIFF);
else {
llvm::errs() << "unknown argument activity to parse, found: '" << str
<< "'\n";
assert(0 && " unknown constant");
}
}

DIFFE_TYPE retType = retTy.getValue();
MTypeAnalysis TA;
auto type_args = TA.getAnalyzedTypeInfo(fn);

bool freeMemory = true;
size_t width = 1;

std::vector<bool> volatile_args;
for (auto &a : fn.getFunctionBody().getArguments()) {
(void)a;
volatile_args.push_back(!(mode == DerivativeMode::ReverseModeCombined));
}

FunctionOpInterface newFunc = Logic.CreateForwardDiff(
fn, retType, constants, TA,
/*should return*/ false, mode, freeMemory, width,
/*addedType*/ nullptr, type_args, volatile_args,
/*augmented*/ nullptr);
if (outfn == "") {
fn->erase();
} else {
SymbolTable::setSymbolName(cast<FunctionOpInterface>(newFunc),
(std::string)outfn);
}
}
};

} // end anonymous namespace

namespace mlir {
namespace enzyme {
std::unique_ptr<Pass> createArithRaisingPass() {
return std::make_unique<ArithRaisingPass>();
}
} // namespace enzyme
} // namespace mlir

39 changes: 39 additions & 0 deletions src/enzyme_ad/jax/Passes/PassDetails.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
//===- PassDetails.h - Enzyme pass class details ----------------*- C++
//-*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Stuff shared between the different polygeist passes.
//
//===----------------------------------------------------------------------===//

// clang-tidy seems to expect the absolute path in the header guard on some
// systems, so just disable it.
// NOLINTNEXTLINE(llvm-header-guard)
#ifndef DIALECT_ENZYMEXLA_TRANSFORMS_PASSDETAILS_H
#define DIALECT_ENZYMEXLA_TRANSFORMS_PASSDETAILS_H

#include "Dialect/Ops.h"
#include "Passes/Passes.h"
#include "mlir/Pass/Pass.h"

namespace mlir {
class FunctionOpInterface;
// Forward declaration from Dialect.h
template <typename ConcreteDialect>
void registerDialect(DialectRegistry &registry);
namespace enzyme {

class EnzymeDialect;

#define GEN_PASS_CLASSES
#include "Passes/Passes.h.inc"

} // namespace enzyme
} // namespace mlir

#endif // DIALECT_ENZYME_TRANSFORMS_PASSDETAILS_H
66 changes: 66 additions & 0 deletions src/enzyme_ad/jax/Passes/Passes.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
//===- Passes.h - Enzyme pass include header -----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef ENZYMEXLA_PASSES_H
#define ENZYMEXLA_PASSES_H

#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
#include "mlir/Pass/Pass.h"
#include <memory>

#include "Enzyme/MLIR/Dialect/Dialect.h"

namespace mlir {
class PatternRewriter;
class RewritePatternSet;
class DominanceInfo;
namespace enzyme {
std::unique_ptr<Pass> createArithRaisingPass();
} // namespace enzyme
} // namespace mlir

namespace mlir {
// Forward declaration from Dialect.h
template <typename ConcreteDialect>
void registerDialect(DialectRegistry &registry);

namespace mhlo {
class MhloDialect;
} // end namespace mhlo

namespace arith {
class ArithDialect;
} // end namespace arith

namespace cf {
class ControlFlowDialect;
} // end namespace cf

namespace scf {
class SCFDialect;
} // end namespace scf

namespace memref {
class MemRefDialect;
} // end namespace memref

namespace func {
class FuncDialect;
}

class AffineDialect;
namespace LLVM {
class LLVMDialect;
}

#define GEN_PASS_REGISTRATION
#include "Passes/Passes.h.inc"

} // end namespace mlir

#endif // ENZYMEXLA_PASSES_H

23 changes: 23 additions & 0 deletions src/enzyme_ad/jax/Passes/Passes.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
//===- Passes.td - EnzymeXLA pass tablegen macros ------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef ENZYMEXLA_PASSES
#define ENZYMEXLA_PASSES

include "mlir/Pass/PassBase.td"

def ArithRaisePass : Pass<"arith-raise"> {
let summary = "Raise Arith to mhlo";
let dependentDialects = [
"arith::ArithDialect",
"mhlo::MhloDialect"
];
let constructor = "mlir::enzyme::createArithRaisePass()";
}

#endif

0 comments on commit fd1f337

Please sign in to comment.