-
Notifications
You must be signed in to change notification settings - Fork 12
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
18 changed files
with
189 additions
and
67 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
|
||
|
||
def : HLODerivative<"AddOp", (Op $x, $y), | ||
[ | ||
(DiffeRet), | ||
(DiffeRet), | ||
] | ||
>; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,14 @@ | ||
include "Common.td" | ||
|
||
def : MLIRDerivative<"mhlo", "AddOp", (Op $x, $y), | ||
[ | ||
(DiffeRet), | ||
(DiffeRet), | ||
] | ||
>; | ||
class HLODerivative<string opName_, dag patternToMatch, list<dag> resultOps> : MLIRDerivative<"mhlo", opName_, patternToMatch, resultOps>; | ||
|
||
class HLOInst<string m> : Inst<m, "mhlo">; | ||
|
||
def Add : HLOInst<"AddOp">; | ||
def Sub : HLOInst<"mhlo::SubOp">; | ||
def Neg : HLOInst<"mhlo::NegOp">; | ||
def Mul : HLOInst<"mhlo::MulOp">; | ||
def Div : HLOInst<"mhlo::DivOp">; | ||
def Rem : HLOInst<"mhlo::RemOp">; | ||
|
||
include "HLODerivatives.td" |
40 changes: 40 additions & 0 deletions
40
src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
//===- ArithAutoDiffOpInterfaceImpl.cpp - Interface external model --------===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
// | ||
// This file contains the external model implementation of the automatic | ||
// differentiation op interfaces for the upstream MLIR arithmetic dialect. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h" | ||
#include "Enzyme/MLIR/Interfaces/AutoDiffOpInterface.h" | ||
#include "Enzyme/MLIR/Interfaces/GradientUtils.h" | ||
#include "Enzyme/MLIR/Interfaces/GradientUtilsReverse.h" | ||
#include "mlir/IR/DialectRegistry.h" | ||
#include "mlir/Support/LogicalResult.h" | ||
|
||
#include "Dialect/Ops.h" | ||
#include "mlir/IR/TypeSupport.h" | ||
|
||
#include "stablehlo/dialect/StablehloOps.h" | ||
|
||
#include "src/enzyme_ad/jax/Implementations/XLADerivatives.h" | ||
|
||
using namespace mlir; | ||
using namespace mlir::enzyme; | ||
|
||
namespace { | ||
#include "src/enzyme_ad/jax/Implementations/StableHLODerivatives.inc" | ||
} // namespace | ||
|
||
void mlir::enzyme::registerStableHLODialectAutoDiffInterface( | ||
DialectRegistry ®istry) { | ||
registry.addExtension(+[](MLIRContext *context, stablehlo::StablehloDialect *) { | ||
registerInterfaces(context); | ||
}); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
include "Common.td" | ||
|
||
class HLODerivative<string opName_, dag patternToMatch, list<dag> resultOps> : MLIRDerivative<"stablehlo", opName_, patternToMatch, resultOps>; | ||
|
||
class HLOInst<string m> : Inst<m, "mhlo">; | ||
|
||
def Add : HLOInst<"AddOp">; | ||
def Sub : HLOInst<"mhlo::SubOp">; | ||
def Neg : HLOInst<"mhlo::NegOp">; | ||
def Mul : HLOInst<"mhlo::MulOp">; | ||
def Div : HLOInst<"mhlo::DivOp">; | ||
def Rem : HLOInst<"mhlo::RemOp">; | ||
|
||
include "HLODerivatives.td" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
//===- PrintPass.cpp - Print the MLIR module ------------ // | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
// | ||
// This file implements a pass to print the MLIR module | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "src/enzyme_ad/jax/Passes/Passes.h" | ||
#include "src/enzyme_ad/jax/Passes/PassDetails.h" | ||
|
||
#define DEBUG_TYPE "enzyme" | ||
|
||
using namespace mlir; | ||
using namespace mlir::enzyme; | ||
using namespace enzyme; | ||
|
||
namespace { | ||
struct PrintPass | ||
: public PrintPassBase<PrintPass> { | ||
|
||
void runOnOperation() override { | ||
llvm::errs() << *getOperation() << "\n"; | ||
} | ||
}; | ||
|
||
} // end anonymous namespace | ||
|
||
namespace mlir { | ||
namespace enzyme { | ||
std::unique_ptr<Pass> createPrintPass() { | ||
return std::make_unique<PrintPass>(); | ||
} | ||
} // namespace enzyme | ||
} // namespace mlir | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.