diff --git a/WORKSPACE b/WORKSPACE index 4fae68634..7db8f9f40 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -30,7 +30,7 @@ http_archive( load("@llvm-raw//utils/bazel:configure.bzl", "llvm_configure") llvm_configure(name = "llvm-project", targets = LLVM_TARGETS) -XLA_COMMIT = "abe31d7bb7d49807be0acc5da33647ffa759741b" +XLA_COMMIT = "6ee7005b0dbe29ba0cd077a690db1555ec6de346" XLA_SHA256 = "" http_archive( @@ -61,7 +61,7 @@ load("@rules_python//python/pip_install:repositories.bzl", "pip_install_dependen pip_install_dependencies() ENZYME_COMMIT = "9384fe20caec02bd30f302e32f4f1c1f7ccb7d9d" -ENZYME_SHA256 = "" +ENZYME_SHA256 = "a0846fab3a2927d84ab5c9a51e539e93a9a7d6c6ec5d87722f24074fefd0c3f8" http_archive( name = "enzyme", diff --git a/src/enzyme_ad/jax/Passes/ArithRaising.cpp b/src/enzyme_ad/jax/Passes/ArithRaising.cpp new file mode 100644 index 000000000..0449a5a44 --- /dev/null +++ b/src/enzyme_ad/jax/Passes/ArithRaising.cpp @@ -0,0 +1,106 @@ +//===- EnzymeWrapPass.cpp - Replace calls with their derivatives ------------ // +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a pass to create wrapper functions which differentiate +// ops. +//===----------------------------------------------------------------------===// + +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "src/enzyme_ad/jax/Passes/Passes.h" +#include "src/enzyme_ad/jax/Passes/PassDetails.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/Interfaces/ControlFlowInterfaces.h" +#include "mlir/Interfaces/FunctionInterfaces.h" + +#include "mlir/Dialect/Func/IR/FuncOps.h" + +#define DEBUG_TYPE "enzyme" + +using namespace mlir; +using namespace mlir::enzyme; +using namespace enzyme; + +namespace { +struct ArithRaisingPass + : public ArithRaisingPassBase { + + void runOnOperation() override { + + Operation *symbolOp = nullptr; + if (infn != "") + symbolOp = symbolTable.lookupSymbolIn( + getOperation(), StringAttr::get(getOperation()->getContext(), infn)); + else { + for (auto &op : getOperation()->getRegion(0).front()) { + auto fn = dyn_cast(symbolOp); + if (!fn) + continue; + assert(symbolOp == nullptr); + symbolOp = &op; + } + } + auto fn = cast(symbolOp); + SmallVector split; + StringRef(argTys.getValue().data(), argTys.getValue().size()) + .split(split, ','); + std::vector constants; + for (auto &str : split) { + if (str == "enzyme_dup") + constants.push_back(DIFFE_TYPE::DUP_ARG); + else if (str == "enzyme_const") + constants.push_back(DIFFE_TYPE::CONSTANT); + else if (str == "enzyme_dupnoneed") + constants.push_back(DIFFE_TYPE::DUP_NONEED); + else if (str == "enzyme_out") + constants.push_back(DIFFE_TYPE::OUT_DIFF); + else { + llvm::errs() << "unknown argument activity to parse, found: '" << str + << "'\n"; + assert(0 && " unknown constant"); + } + } + + DIFFE_TYPE retType = retTy.getValue(); + MTypeAnalysis TA; + auto type_args = TA.getAnalyzedTypeInfo(fn); + + bool freeMemory = true; + size_t width = 1; + + std::vector volatile_args; + for (auto &a : fn.getFunctionBody().getArguments()) { + (void)a; + volatile_args.push_back(!(mode == DerivativeMode::ReverseModeCombined)); + } + + FunctionOpInterface newFunc = Logic.CreateForwardDiff( + fn, retType, constants, TA, + /*should return*/ false, mode, freeMemory, width, + /*addedType*/ nullptr, type_args, volatile_args, + /*augmented*/ nullptr); + if (outfn == "") { + fn->erase(); + } else { + SymbolTable::setSymbolName(cast(newFunc), + (std::string)outfn); + } + } +}; + +} // end anonymous namespace + +namespace mlir { +namespace enzyme { +std::unique_ptr createArithRaisingPass() { + return std::make_unique(); +} +} // namespace enzyme +} // namespace mlir + diff --git a/src/enzyme_ad/jax/Passes/PassDetails.h b/src/enzyme_ad/jax/Passes/PassDetails.h new file mode 100644 index 000000000..63a97aad7 --- /dev/null +++ b/src/enzyme_ad/jax/Passes/PassDetails.h @@ -0,0 +1,39 @@ +//===- PassDetails.h - Enzyme pass class details ----------------*- C++ +//-*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Stuff shared between the different polygeist passes. +// +//===----------------------------------------------------------------------===// + +// clang-tidy seems to expect the absolute path in the header guard on some +// systems, so just disable it. +// NOLINTNEXTLINE(llvm-header-guard) +#ifndef DIALECT_ENZYMEXLA_TRANSFORMS_PASSDETAILS_H +#define DIALECT_ENZYMEXLA_TRANSFORMS_PASSDETAILS_H + +#include "Dialect/Ops.h" +#include "Passes/Passes.h" +#include "mlir/Pass/Pass.h" + +namespace mlir { +class FunctionOpInterface; +// Forward declaration from Dialect.h +template +void registerDialect(DialectRegistry ®istry); +namespace enzyme { + +class EnzymeDialect; + +#define GEN_PASS_CLASSES +#include "Passes/Passes.h.inc" + +} // namespace enzyme +} // namespace mlir + +#endif // DIALECT_ENZYME_TRANSFORMS_PASSDETAILS_H diff --git a/src/enzyme_ad/jax/Passes/Passes.h b/src/enzyme_ad/jax/Passes/Passes.h new file mode 100644 index 000000000..cd53d4c17 --- /dev/null +++ b/src/enzyme_ad/jax/Passes/Passes.h @@ -0,0 +1,66 @@ +//===- Passes.h - Enzyme pass include header -----------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#ifndef ENZYMEXLA_PASSES_H +#define ENZYMEXLA_PASSES_H + +#include "mlir/Conversion/LLVMCommon/LoweringOptions.h" +#include "mlir/Pass/Pass.h" +#include + +#include "Enzyme/MLIR/Dialect/Dialect.h" + +namespace mlir { +class PatternRewriter; +class RewritePatternSet; +class DominanceInfo; +namespace enzyme { +std::unique_ptr createArithRaisingPass(); +} // namespace enzyme +} // namespace mlir + +namespace mlir { +// Forward declaration from Dialect.h +template +void registerDialect(DialectRegistry ®istry); + +namespace mhlo { +class MhloDialect; +} // end namespace mhlo + +namespace arith { +class ArithDialect; +} // end namespace arith + +namespace cf { +class ControlFlowDialect; +} // end namespace cf + +namespace scf { +class SCFDialect; +} // end namespace scf + +namespace memref { +class MemRefDialect; +} // end namespace memref + +namespace func { +class FuncDialect; +} + +class AffineDialect; +namespace LLVM { +class LLVMDialect; +} + +#define GEN_PASS_REGISTRATION +#include "Passes/Passes.h.inc" + +} // end namespace mlir + +#endif // ENZYMEXLA_PASSES_H + diff --git a/src/enzyme_ad/jax/Passes/Passes.td b/src/enzyme_ad/jax/Passes/Passes.td new file mode 100644 index 000000000..bbec8f6cb --- /dev/null +++ b/src/enzyme_ad/jax/Passes/Passes.td @@ -0,0 +1,23 @@ +//===- Passes.td - EnzymeXLA pass tablegen macros ------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef ENZYMEXLA_PASSES +#define ENZYMEXLA_PASSES + +include "mlir/Pass/PassBase.td" + +def ArithRaisePass : Pass<"arith-raise"> { + let summary = "Raise Arith to mhlo"; + let dependentDialects = [ + "arith::ArithDialect", + "mhlo::MhloDialect" + ]; + let constructor = "mlir::enzyme::createArithRaisePass()"; +} + +#endif