Skip to content

Commit 5e21673

Browse files
authored
Merge pull request #22262 from JuliaLang/yyc/codegen/muladd
Add a custom LLVM pass to replace fastmath multiply and add with muladd
2 parents e5cc4d7 + b2ee098 commit 5e21673

File tree

7 files changed

+176
-6
lines changed

7 files changed

+176
-6
lines changed

base/math.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -941,8 +941,11 @@ mod2pi(x) = rem2pi(x,RoundDown)
941941
"""
942942
muladd(x, y, z)
943943
944-
Combined multiply-add, computes `x*y+z` in an efficient manner. This may on some systems be
945-
equivalent to `x*y+z`, or to `fma(x,y,z)`. `muladd` is used to improve performance.
944+
Combined multiply-add: computes `x*y+z`, allowing the add and multiply to be contracted with
945+
each other, or with operations from other `muladd` calls and `@fastmath` code, to form an `fma`
946+
if the transformation can improve performance.
947+
The result can be different on different machines and can also be different on the same machine
948+
due to constant propagation or other optimizations.
946949
See [`fma`](@ref).
947950
948951
# Example

src/Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ endif
5050
LLVMLINK :=
5151

5252
ifeq ($(JULIACODEGEN),LLVM)
53-
SRCS += codegen jitlayers disasm debuginfo llvm-simdloop llvm-ptls llvm-late-gc-lowering llvm-lower-handlers llvm-gc-invariant-verifier llvm-propagate-addrspaces cgmemmgr
53+
SRCS += codegen jitlayers disasm debuginfo llvm-simdloop llvm-ptls llvm-muladd llvm-late-gc-lowering llvm-lower-handlers llvm-gc-invariant-verifier llvm-propagate-addrspaces cgmemmgr
5454
FLAGS += -I$(shell $(LLVM_CONFIG_HOST) --includedir)
5555
LLVM_LIBS := all
5656
ifeq ($(USE_POLLY),1)

src/intrinsics.cpp

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -704,17 +704,23 @@ static Value *emit_checked_srem_int(jl_codectx_t &ctx, Value *x, Value *den)
704704
struct math_builder {
705705
IRBuilder<> &ctxbuilder;
706706
FastMathFlags old_fmf;
707-
math_builder(jl_codectx_t &ctx, bool always_fast = false)
707+
math_builder(jl_codectx_t &ctx, bool always_fast = false, bool contract = false)
708708
: ctxbuilder(ctx.builder),
709709
old_fmf(ctxbuilder.getFastMathFlags())
710710
{
711+
FastMathFlags fmf;
711712
if (jl_options.fast_math != JL_OPTIONS_FAST_MATH_OFF &&
712713
(always_fast ||
713714
jl_options.fast_math == JL_OPTIONS_FAST_MATH_ON)) {
714-
FastMathFlags fmf;
715715
fmf.setUnsafeAlgebra();
716-
ctxbuilder.setFastMathFlags(fmf);
717716
}
717+
#if JL_LLVM_VERSION >= 50000
718+
if (contract)
719+
fmf.setAllowContract(true);
720+
#else
721+
assert(!contract);
722+
#endif
723+
ctxbuilder.setFastMathFlags(fmf);
718724
}
719725
IRBuilder<>& operator()() const { return ctxbuilder; }
720726
~math_builder() {
@@ -913,10 +919,18 @@ static Value *emit_untyped_intrinsic(jl_codectx_t &ctx, intrinsic f, Value **arg
913919
return ctx.builder.CreateCall(fmaintr, {x, y, z});
914920
}
915921
case muladd_float: {
922+
#if JL_LLVM_VERSION >= 50000
923+
// LLVM 5.0 can create FMA in the backend for contractable fmul and fadd
924+
// Emitting fmul and fadd here since they are easier for other LLVM passes to
925+
// optimize.
926+
auto mathb = math_builder(ctx, false, true);
927+
return mathb().CreateFAdd(mathb().CreateFMul(x, y), z);
928+
#else
916929
assert(y->getType() == x->getType());
917930
assert(z->getType() == y->getType());
918931
Value *muladdintr = Intrinsic::getDeclaration(jl_Module, Intrinsic::fmuladd, makeArrayRef(t));
919932
return ctx.builder.CreateCall(muladdintr, {x, y, z});
933+
#endif
920934
}
921935

922936
case checked_sadd_int:

src/jitlayers.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,7 @@ void addOptimizationPasses(legacy::PassManagerBase *PM, int opt_level)
242242
PM->add(createDeadCodeEliminationPass());
243243
PM->add(createLowerPTLSPass(imaging_mode));
244244
#endif
245+
PM->add(createCombineMulAddPass());
245246
}
246247

247248
extern "C" JL_DLLEXPORT

src/jitlayers.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,7 @@ extern JuliaOJIT *jl_ExecutionEngine;
202202
JL_DLLEXPORT extern LLVMContext jl_LLVMContext;
203203

204204
Pass *createLowerPTLSPass(bool imaging_mode);
205+
Pass *createCombineMulAddPass();
205206
Pass *createLateLowerGCFramePass();
206207
Pass *createLowerExcHandlersPass();
207208
Pass *createGCInvariantVerifierPass(bool Strong);

src/llvm-muladd.cpp

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
// This file is a part of Julia. License is MIT: https://julialang.org/license
2+
3+
#define DEBUG_TYPE "combine_muladd"
4+
#undef DEBUG
5+
#include "llvm-version.h"
6+
7+
#include <llvm/IR/Value.h>
8+
#include <llvm/IR/Function.h>
9+
#include <llvm/IR/Instructions.h>
10+
#include <llvm/IR/IntrinsicInst.h>
11+
#include <llvm/IR/Module.h>
12+
#include <llvm/IR/Operator.h>
13+
#include <llvm/IR/IRBuilder.h>
14+
#include <llvm/Pass.h>
15+
#include <llvm/Support/Debug.h>
16+
#include "fix_llvm_assert.h"
17+
18+
#include "julia.h"
19+
20+
using namespace llvm;
21+
22+
/**
23+
* Combine
24+
* ```
25+
* %v0 = fmul ... %a, %b
26+
* %v = fadd fast ... %v0, %c
27+
* ```
28+
* to
29+
* `%v = call fast @llvm.fmuladd.<...>(... %a, ... %b, ... %c)`
30+
* when `%v0` has no other use
31+
*/
32+
33+
struct CombineMulAdd : public FunctionPass {
34+
static char ID;
35+
CombineMulAdd() : FunctionPass(ID)
36+
{}
37+
38+
private:
39+
bool runOnFunction(Function &F) override;
40+
};
41+
42+
// Return true if this function shouldn't be called again on the other operand
43+
// This will always return false on LLVM 5.0+
44+
static bool checkCombine(Module *m, Instruction *addOp, Value *maybeMul, Value *addend,
45+
bool negadd, bool negres)
46+
{
47+
auto mulOp = dyn_cast<Instruction>(maybeMul);
48+
if (!mulOp || mulOp->getOpcode() != Instruction::FMul)
49+
return false;
50+
if (!mulOp->hasOneUse())
51+
return false;
52+
#if JL_LLVM_VERSION >= 50000
53+
// On 5.0+ we only need to mark the mulOp as contract and the backend will do the work for us.
54+
auto fmf = mulOp->getFastMathFlags();
55+
fmf.setAllowContract(true);
56+
mulOp->copyFastMathFlags(fmf);
57+
return false;
58+
#else
59+
IRBuilder<> builder(m->getContext());
60+
builder.SetInsertPoint(addOp);
61+
auto mul1 = mulOp->getOperand(0);
62+
auto mul2 = mulOp->getOperand(1);
63+
Value *muladdf = Intrinsic::getDeclaration(m, Intrinsic::fmuladd, addOp->getType());
64+
if (negadd) {
65+
auto newaddend = builder.CreateFNeg(addend);
66+
// Might be a const
67+
if (auto neginst = dyn_cast<Instruction>(newaddend))
68+
neginst->setHasUnsafeAlgebra(true);
69+
addend = newaddend;
70+
}
71+
Instruction *newv = builder.CreateCall(muladdf, {mul1, mul2, addend});
72+
newv->setHasUnsafeAlgebra(true);
73+
if (negres) {
74+
// Shouldn't be a constant
75+
newv = cast<Instruction>(builder.CreateFNeg(newv));
76+
newv->setHasUnsafeAlgebra(true);
77+
}
78+
addOp->replaceAllUsesWith(newv);
79+
addOp->eraseFromParent();
80+
mulOp->eraseFromParent();
81+
return true;
82+
#endif
83+
}
84+
85+
bool CombineMulAdd::runOnFunction(Function &F)
86+
{
87+
Module *m = F.getParent();
88+
for (auto &BB: F) {
89+
for (auto it = BB.begin(); it != BB.end();) {
90+
auto &I = *it;
91+
it++;
92+
switch (I.getOpcode()) {
93+
case Instruction::FAdd: {
94+
if (!I.hasUnsafeAlgebra())
95+
continue;
96+
checkCombine(m, &I, I.getOperand(0), I.getOperand(1), false, false) ||
97+
checkCombine(m, &I, I.getOperand(1), I.getOperand(0), false, false);
98+
break;
99+
}
100+
case Instruction::FSub: {
101+
if (!I.hasUnsafeAlgebra())
102+
continue;
103+
checkCombine(m, &I, I.getOperand(0), I.getOperand(1), true, false) ||
104+
checkCombine(m, &I, I.getOperand(1), I.getOperand(0), true, true);
105+
break;
106+
}
107+
default:
108+
break;
109+
}
110+
}
111+
}
112+
return true;
113+
}
114+
115+
char CombineMulAdd::ID = 0;
116+
static RegisterPass<CombineMulAdd> X("CombineMulAdd", "Combine mul and add to muladd",
117+
false /* Only looks at CFG */,
118+
false /* Analysis Pass */);
119+
120+
// Factory for the pass; returns a newly allocated CombineMulAdd instance.
Pass *createCombineMulAddPass()
{
    Pass *pass = new CombineMulAdd();
    return pass;
}

test/llvmpasses/muladd.ll

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
; RUN: opt -load libjulia.so -CombineMulAdd -S %s | FileCheck %s

; Each function is anchored with CHECK-LABEL so that a match for one function
; cannot be satisfied by output belonging to a later one.

; a * b + c
; CHECK-LABEL: @fast_muladd1(
define double @fast_muladd1(double %a, double %b, double %c) {
top:
; CHECK: {{contract|fmuladd}}
  %v1 = fmul double %a, %b
  %v2 = fadd fast double %v1, %c
; CHECK: ret double
  ret double %v2
}

; a * b - c
; CHECK-LABEL: @fast_mulsub1(
define double @fast_mulsub1(double %a, double %b, double %c) {
top:
; CHECK: {{contract|fmuladd}}
  %v1 = fmul double %a, %b
  %v2 = fsub fast double %v1, %c
; CHECK: ret double
  ret double %v2
}

; c - a * b, on vectors
; CHECK-LABEL: @fast_mulsub_vec1(
define <2 x double> @fast_mulsub_vec1(<2 x double> %a, <2 x double> %b, <2 x double> %c) {
top:
; CHECK: {{contract|fmuladd}}
  %v1 = fmul <2 x double> %a, %b
  %v2 = fsub fast <2 x double> %c, %v1
; CHECK: ret <2 x double>
  ret <2 x double> %v2
}

0 commit comments

Comments
 (0)