Skip to content

Commit

Permalink
Raise powi and absi to the math dialect
Browse files Browse the repository at this point in the history
Closes #246.
ftynse authored and wsmoses committed Jan 23, 2025
1 parent 8404cda commit faf1125
Showing 2 changed files with 43 additions and 10 deletions.
20 changes: 10 additions & 10 deletions src/enzyme_ad/jax/Passes/LibDeviceFuncsRaisingPass.cpp
Original file line number Diff line number Diff line change
@@ -171,17 +171,14 @@ class CallToOpRaising : public OpRewritePattern<LLVM::CallOp> {
};
} // namespace

// NOTE(review): this span appears to be a rendered diff whose +/- markers were
// stripped, so the removed and the added definition of populateOpPatterns are
// interleaved below; it is not compilable as written. The comments label which
// version each group of lines belongs to.
//
// Header of the removed version: a single template parameter.
template <typename TargetOp>
// Header of the added version: one mandatory name argument (Arg) followed by
// any number of further name arguments (Args).
template <typename TargetOp, typename Arg, typename... Args>
static void populateOpPatterns(MLIRContext *context,
// Removed parameter list and body: added a CallToOpRaising<TargetOp> pattern
// for f32Func and f64Func unconditionally, and for f32ApproxFunc and f16Func
// only when those names were non-empty (both default to "").
RewritePatternSet &patterns, StringRef f32Func,
StringRef f64Func, StringRef f32ApproxFunc = "",
StringRef f16Func = "") {
patterns.add<CallToOpRaising<TargetOp>>(context, f32Func);
patterns.add<CallToOpRaising<TargetOp>>(context, f64Func);
if (!f32ApproxFunc.empty())
patterns.add<CallToOpRaising<TargetOp>>(context, f32ApproxFunc);
if (!f16Func.empty())
patterns.add<CallToOpRaising<TargetOp>>(context, f16Func);
// Added parameter list and body: adds one CallToOpRaising<TargetOp> pattern for
// the first name, then calls itself with the remaining names. The
// `if constexpr` guard stops the recursion once the parameter pack is empty,
// so a call with a single name is valid.
RewritePatternSet &patterns, Arg &&arg,
Args &&...args) {
patterns.add<CallToOpRaising<TargetOp>>(context, std::forward<Arg>(arg));
if constexpr (sizeof...(Args) != 0)
populateOpPatterns<TargetOp>(context, patterns,
std::forward<Args>(args)...);
}

namespace {
@@ -399,6 +396,9 @@ void mlir::enzyme::populateLibDeviceFuncsToOpsPatterns(
"__nv_fast_tanf");
populateOpPatterns<math::TanhOp>(converter, patterns, "__nv_tanhf",
"__nv_tanh");
populateOpPatterns<math::FPowIOp>(converter, patterns, "__nv_powif",
"__nv_powi");
populateOpPatterns<math::AbsIOp>(converter, patterns, "__nv_abs");
}

void populateLLVMToMathPatterns(MLIRContext *context,
33 changes: 33 additions & 0 deletions test/lit_tests/libdevice_raise.mlir
Original file line number Diff line number Diff line change
@@ -649,5 +649,38 @@ module {
llvm.return
}
}

// Test fixture for the integer-power raising: both the f32 and the f64 variant
// are called once from the same function.
gpu.module @test_module_46 {
// Declarations of the two external functions being called; each takes a
// floating-point base and an i32 exponent.
llvm.func @__nv_powif(f32, i32) -> f32
llvm.func @__nv_powi(f64, i32) -> f64
llvm.func @gpu_powi(%arg0: f32, %arg1: f64, %arg2: i32) -> !llvm.struct<(f32, f64)> attributes {llvm.emit_c_interface} {
// The directive below expects the raised op to appear twice, once per call.
// CHECK-COUNT-2: math.fpowi
%0 = llvm.call @__nv_powif(%arg0, %arg2) : (f32, i32) -> f32
%1 = llvm.call @__nv_powi(%arg1, %arg2) : (f64, i32) -> f64
// Pack both results into one struct so that both calls have a use.
%2 = llvm.mlir.undef : !llvm.struct<(f32, f64)>
%3 = llvm.insertvalue %0, %2[0] : !llvm.struct<(f32, f64)>
%4 = llvm.insertvalue %1, %3[1] : !llvm.struct<(f32, f64)>
llvm.return %4 : !llvm.struct<(f32, f64)>
}
// Wrapper that forwards its arguments to @gpu_powi and stores the returned
// struct through the pointer passed as the first argument.
llvm.func @_mlir_ciface_gpu_powi(%arg0: !llvm.ptr, %arg1: f32, %arg2: f64, %arg3: i32) attributes {llvm.emit_c_interface} {
%0 = llvm.call @gpu_powi(%arg1, %arg2, %arg3) : (f32, f64, i32) -> !llvm.struct<(f32, f64)>
llvm.store %0, %arg0 : !llvm.struct<(f32, f64)>, !llvm.ptr
llvm.return
}
}

// Test fixture for the integer absolute-value raising: a single i32 call.
gpu.module @test_module_47 {
// Declaration of the external function being called.
llvm.func @__nv_abs(i32) -> i32
llvm.func @gpu_abs(%arg0: i32) -> i32 attributes {llvm.emit_c_interface} {
// The directive below expects the raised op to appear in the output.
// CHECK: math.absi
%0 = llvm.call @__nv_abs(%arg0) : (i32) -> i32
llvm.return %0 : i32
}
// Wrapper that forwards its i32 argument to @gpu_abs and stores the result
// through the pointer passed as the first argument.
llvm.func @_mlir_ciface_gpu_abs(%arg0 : !llvm.ptr, %arg1: i32) attributes {llvm.emit_c_interface} {
%0 = llvm.call @gpu_abs(%arg1) : (i32) -> i32
llvm.store %0, %arg0 : i32, !llvm.ptr
llvm.return
}
}
}

0 comments on commit faf1125

Please sign in to comment.