Revert "Use fast math function for tl.math.log as exp (#4723)" (#4779)
This reverts commit 84fe9da.
ThomasRaoux committed Sep 21, 2024
1 parent 93c2027 commit 3a647f0
Showing 4 changed files with 40 additions and 19 deletions.
3 changes: 2 additions & 1 deletion lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
@@ -346,7 +346,8 @@ struct ElementwiseOpConversion
                                    ConversionPatternRewriter &rewriter,
                                    Type elemTy, MultipleOperandsRange operands,
                                    Location loc) const {
-    return {rewriter.create<DestOp>(loc, elemTy, operands[0], op->getAttrs())};
+    return {rewriter.create<DestOp>(loc, elemTy, operands[0],
+                                    adaptor.getAttributes().getValue())};
   }
 };

18 changes: 6 additions & 12 deletions python/src/ir.cc
@@ -1,4 +1,4 @@
-#include <pybind11/functional.h>
+#include <pybind11/functional.h>
 #include <pybind11/pybind11.h>
 #include <pybind11/stl.h>

@@ -135,11 +135,6 @@ void outputWarning(Location loc, const std::string &msg) {
                 /*stack_level=*/2);
 }
 
-template <typename OpTy> OpTy approxMath(OpTy op) {
-  op.setFastmath(arith::FastMathFlags::afn);
-  return op;
-}
-
 } // anonymous namespace
 
 /*****************************************************************************/
@@ -1453,28 +1448,27 @@ void init_triton_ir(py::module &&m) {
           })
       .def("create_exp",
            [](TritonOpBuilder &self, Value &val) -> Value {
-             return approxMath(self.create<math::ExpOp>(val));
+             return self.create<math::ExpOp>(val);
            })
       .def("create_exp2",
            [](TritonOpBuilder &self, Value &val) -> Value {
-             return approxMath(self.create<math::Exp2Op>(val));
+             return self.create<math::Exp2Op>(val);
            })
       .def("create_cos",
            [](TritonOpBuilder &self, Value &val) -> Value {
-             return approxMath(self.create<math::CosOp>(val));
+             return self.create<math::CosOp>(val);
            })
       .def("create_sin",
            [](TritonOpBuilder &self, Value &val) -> Value {
-             return approxMath(self.create<math::SinOp>(val));
+             return self.create<math::SinOp>(val);
            })
       .def("create_log",
            [](TritonOpBuilder &self, Value &val) -> Value {
-             // TODO: switch to approxMath.
              return self.create<math::LogOp>(val);
            })
       .def("create_log2",
            [](TritonOpBuilder &self, Value &val) -> Value {
-             return approxMath(self.create<math::Log2Op>(val));
+             return self.create<math::Log2Op>(val);
            })
       .def("create_erf",
            [](TritonOpBuilder &self, Value &val) -> Value {
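
These create_* bindings are the hooks the Python frontend reaches when lowering tl.exp, tl.exp2, tl.cos, tl.sin, tl.log, and tl.log2; after this revert they build plain MLIR math ops with no afn fastmath flag attached. Below is a minimal sketch of a kernel exercising the create_exp path (the launch pattern mirrors the test change further down; the kernel name and sizes are illustrative):

import torch
import triton
import triton.language as tl


@triton.jit
def exp_kernel(X, Y, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    x = tl.load(X + offs)
    # tl.exp lowers through create_exp above; post-revert the frontend emits
    # math.exp with no fastmath flags, leaving any approximation decision to
    # the backend lowering.
    tl.store(Y + offs, tl.exp(x))


x = torch.randn(128, dtype=torch.float32, device='cuda')
y = torch.empty_like(x)
exp_kernel[(1, )](x, y, BLOCK=128)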
7 changes: 1 addition & 6 deletions python/test/unit/language/test_core.py
@@ -4382,14 +4382,9 @@ def kernel(X, Y, BLOCK: tl.constexpr):
     x = torch.max(x, torch.tensor(1e-6, dtype=torch.float32, device=device))
     y = torch.zeros(shape, dtype=torch.float32, device=device)
 
-    k = kernel[(1, )](x, y, BLOCK=shape[0])
+    kernel[(1, )](x, y, BLOCK=shape[0])
     torch.allclose(getattr(torch, func_str)(x), y, rtol=1e-3)
 
-    if func_str in ['log2'] and is_cuda():
-        assert 'lg2.approx.ftz.f32' in k.asm['ptx']
-    if func_str in ['exp', 'exp2'] and is_cuda():
-        assert 'ex2.approx.ftz.f32' in k.asm['ptx']
-
 
 # -----------------------
 # test inline asm
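
The deleted assertions pinned the compiled PTX to the fast-math instructions for exp, exp2, and log2. After the revert, only FP32 exp is still expected to reach ex2.approx, via the ExpOpConversionApprox pattern restored below; a hedged sketch of such a check (not part of this commit) might look like:

import torch
import triton
import triton.language as tl


@triton.jit
def exp_kernel(X, Y, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    tl.store(Y + offs, tl.exp(tl.load(X + offs)))


x = torch.rand(128, dtype=torch.float32, device='cuda')
y = torch.zeros_like(x)
k = exp_kernel[(1, )](x, y, BLOCK=128)
# FP32 exp should still be lowered through ExpOpConversionApprox (see the
# backend pattern below), so the approximate exp2 instruction is expected.
assert 'ex2.approx' in k.asm['ptx']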
31 changes: 31 additions & 0 deletions third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp
@@ -755,6 +755,32 @@ struct TruncFOpConversion
   }
 };
 
+struct ExpOpConversionApprox
+    : ElementwiseOpConversionBase<math::ExpOp, ExpOpConversionApprox> {
+  using Base = ElementwiseOpConversionBase<math::ExpOp, ExpOpConversionApprox>;
+  using Base::Base;
+  using Adaptor = typename Base::OpAdaptor;
+
+  SmallVector<Value> createDestOps(math::ExpOp op, OpAdaptor adaptor,
+                                   ConversionPatternRewriter &rewriter,
+                                   Type elemTy, MultipleOperandsRange operands,
+                                   Location loc) const {
+    // For non-FP32 input, call __nv_expf for higher-precision calculation
+    if (elemTy.getIntOrFloatBitWidth() != 32)
+      return {};
+
+    const double log2e = 1.4426950408889634;
+    Value prod = fmul(f32_ty, operands[0][0], f32_val(log2e));
+
+    PTXBuilder ptxBuilder;
+    auto &exp2 = ptxBuilder.create<PTXInstr>("ex2")->o("approx").o("f32");
+    auto output = ptxBuilder.newOperand("=f");
+    auto input = ptxBuilder.newOperand(prod, "f");
+    exp2(output, input);
+    return {ptxBuilder.launch(rewriter, loc, f32_ty, false)};
+  }
+};
+
 struct ClampFOpConversion
     : ElementwiseOpConversionBase<ClampFOp, ClampFOpConversion> {
   using Base = ElementwiseOpConversionBase<ClampFOp, ClampFOpConversion>;
@@ -925,6 +951,11 @@ void mlir::triton::NVIDIA::populateElementwiseOpToLLVMPatterns(
   patterns.add<FpToFpOpConversion>(typeConverter, axisInfoAnalysis,
                                    computeCapability, benefit);
 
+  // ExpOpConversionApprox will try using ex2.approx if the input type is
+  // FP32. For other input types, ExpOpConversionApprox will return failure and
+  // ElementwiseOpConversion<math::ExpOp, math::ExpOp> defined below will call
+  // __nv_expf for higher-precision calculation
+  patterns.add<ExpOpConversionApprox>(typeConverter, axisInfoAnalysis, benefit);
   bool hwNanPropagationSupported = computeCapability >= 80;
   mlir::triton::populateMinMaxFOpToLLVMPattern(
       typeConverter, patterns, axisInfoAnalysis, hwNanPropagationSupported,
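
The restored ExpOpConversionApprox relies on the identity exp(x) = 2^(x * log2(e)): it scales the input by log2(e) ≈ 1.4426950408889634 and issues a single ex2.approx.f32 PTX instruction on the product. A quick numerical check of that identity in plain Python, independent of Triton:

import math


def exp_via_exp2(x: float) -> float:
    # Mirrors ExpOpConversionApprox: exp(x) == 2 ** (x * log2(e)); on NVIDIA
    # GPUs the 2**y step maps onto the hardware ex2.approx instruction.
    LOG2E = 1.4426950408889634  # same constant as in the C++ pattern above
    return 2.0 ** (x * LOG2E)


for x in (-2.0, 0.0, 0.5, 3.0):
    assert math.isclose(exp_via_exp2(x), math.exp(x), rel_tol=1e-12)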
