Skip to content

Commit 7e22f7e

Browse files
committed
add LoopInfo + LICM
1 parent 3909294 commit 7e22f7e

File tree

2 files changed

+170
-0
lines changed

2 files changed

+170
-0
lines changed

base/compiler/ssair/driver.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,14 @@ function run_passes(ci::CodeInfo, nargs::Int, sv::OptimizationState)
131131
@timeit "domtree 2" domtree = construct_domtree(ir.cfg)
132132
@timeit "SROA" ir = getfield_elim_pass!(ir, domtree)
133133
#@Base.show ir.new_nodes
134+
ir = compact!(ir)
134135
#@Base.show ("after_sroa", ir)
136+
@timeit "loopinfo" loopinfo = construct_loopinfo(ir, domtree)
137+
if loopinfo !== nothing
138+
@timeit "licm" ir = licm_pass!(ir, loopinfo)
139+
end
140+
#@timeit "verify 3" verify_ir(ir)
141+
#@Base.show ("after_licm", ir)
135142
ir = adce_pass!(ir)
136143
#@Base.show ("after_adce", ir)
137144
@timeit "type lift" ir = type_lift_pass!(ir)

base/compiler/ssair/passes.jl

Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1162,3 +1162,166 @@ function cfg_simplify!(ir::IRCode)
11621162
compact.active_result_bb = length(bb_starts)
11631163
return finish(compact)
11641164
end
1165+
1166+
# Description of a single natural loop in the CFG.
#
# All fields hold basic-block indices into `cfg.blocks`:
#   header  - the block targeted by the loop's back edges
#   latches - the source blocks of the back edges into `header`
#   blocks  - every block belonging to the loop, `header` included
#
# Exiting/exit block lists are intentionally not stored yet; see the
# commented-out fields below.
struct LoopInfo
    header::Int
    latches::Vector{Int}
    # exiting::Vector{Int}
    # exits::Vector{Int}
    blocks::Vector{Int}
end
1173+
# LoopInfo(header, latches, blocks) = LoopInfo(header, latches, Int[], Int[], blocks)
1174+
1175+
# Discover the natural loops of `ir.cfg`.
#
# A back edge is a CFG edge `n -> h` for which `dominates(domtree, h, n)`
# holds. Every back edge contributes the set of blocks that can reach `n`
# without passing through `h` (plus `h` itself); back edges sharing a header
# are merged into a single `LoopInfo`.
#
# Returns `nothing` when the CFG contains no back edge, otherwise an
# `IdDict{Int, LoopInfo}` keyed by the header's block index.
function construct_loopinfo(ir, domtree)
    cfg = ir.cfg

    # 1. Find back edges: edges n -> h where h dominates n.
    backedges = Pair{Int, Int}[]
    for (n, bb) in enumerate(cfg.blocks)
        for succ in bb.succs
            if dominates(domtree, succ, n)
                push!(backedges, n => succ)
            end
        end
    end
    isempty(backedges) && return nothing

    # 2. Compute the loop body for every back edge.
    loops = IdDict{Int, LoopInfo}()
    for (n, h) in backedges
        # Merge loops that have the same header.
        if haskey(loops, h)
            known = loops[h]
            visited = BitSet(known.blocks)
            latches = known.latches
        else
            visited = BitSet((h,))
            latches = Int[]
        end
        push!(latches, n)

        # Create {n′ | n is reachable from n′ in CFG \ {h}} ∪ {h}.
        #
        # Only walk backwards from the latch when it has not been visited yet.
        # The header is always pre-seeded in `visited`, so a self-loop
        # (n == h) must not enqueue the header's predecessors: doing so would
        # leave the loop through its entry edge and swallow every ancestor
        # block. Any other already-visited latch had its predecessors
        # explored while processing an earlier back edge of the same header.
        worklist = Int[]
        if !(n in visited)
            push!(visited, n)
            append!(worklist, cfg.blocks[n].preds)
        end
        while !isempty(worklist)
            idx = pop!(worklist)
            idx in visited && continue

            push!(visited, idx)
            append!(worklist, cfg.blocks[idx].preds)
        end

        # A BitSet iterates in ascending order, so `blocks` comes out sorted
        # by block index (assumed to be CFG order).
        blocks = collect(visited)
        loops[h] = LoopInfo(h, latches, blocks)
    end

    # Find exiting and exit blocks
    # LLVM calculates this on the fly
    # for loop in values(loops)
    #     for bb in loop.blocks
    #         for succ in cfg.blocks[bb].succs
    #             succ ∈ loop.blocks && continue
    #             push!(loop.exiting, bb)
    #             push!(loop.exits, succ)
    #         end
    #     end
    # end

    # TODO: Loop nesting/Control tree
    return loops
end
1233+
1234+
1235+
# Loop-invariant code motion over the loops described by `loops` (the
# dictionary returned by `construct_loopinfo`).
#
# For every loop, statements that `invariant_stmt` reports as invariant are
# re-inserted at the last statement of the loop's unique non-latch header
# predecessor, and the original statement is replaced by a reference to the
# hoisted value. Loops without exactly one such predecessor are skipped.
#
# Returns the result of finishing the `IncrementalCompact` built from `ir`.
#
# XXX: `invariant_stmt` only inspects operands; a statement that is guarded
# inside the loop (e.g. `if (x > 0) sqrt(x)`) is still hoisted. We either need
# to move the entire block or check ?reverse-dominance?.
function licm_pass!(ir, loops)
    cfg = ir.cfg
    ir = IncrementalCompact(ir)
    # TODO: Processing order, we should process innermost to outermost loops,
    #       but we need to maintain LoopInfo under CFG changes.
    for (_, loop) in loops
        # XXX: Need to insert pre-header instead of dumping into predecessor.
        # Until a dedicated pre-header can be created, only loops entered
        # from exactly one non-latch predecessor can be handled. Other shapes
        # are valid CFGs, so skip them rather than aborting compilation.
        header = cfg.blocks[loop.header]
        predecessors = filter(bb -> !(bb in loop.latches), header.preds)
        length(predecessors) == 1 || continue
        pre_header = predecessors[1]

        # Find stmts that are invariant w.r.t. this loop. Blocks are visited
        # in the order stored in `loop.blocks`, so a statement can only build
        # on invariants recorded before it.
        invariant_stmts = Int[]
        for idx in loop.blocks
            bb = cfg.blocks[idx]
            for i in bb.stmts
                stmt = ir[i]
                if invariant_stmt(ir, loop, invariant_stmts, stmt)
                    push!(invariant_stmts, i)
                end
            end
        end

        insertion_point = SSAValue(last(cfg.blocks[pre_header].stmts))
        # Maps the SSA value of an already hoisted statement to its new value.
        valmap = IdDict{SSAValue, Core.Compiler.AnySSAValue}()
        typesmap = types(ir)
        for idx in invariant_stmts
            stmt = ir[idx]
            # Rewrite operands that refer to previously hoisted statements so
            # the hoisted copy does not use values still defined in the loop.
            new_stmt = Core.Compiler.ssamap(stmt) do val
                if haskey(valmap, val)
                    return valmap[val]
                else
                    return val
                end
            end
            typ = typesmap[idx]
            # Insert the remapped statement, not the original one.
            new_ssaval = insert_node!(ir, insertion_point, typ, new_stmt) # XXX: lineinfo
            ir[idx] = new_ssaval
            valmap[SSAValue(idx)] = new_ssaval
        end
    end

    # Just run through the iterator without any processing
    Core.Compiler.foreach(x -> nothing, ir) # x isa Pair{Int, Any}
    return Core.Compiler.finish(ir)
end
1290+
1291+
# Accessor for the control-flow graph of an `IRCode`.
function cfg(ir::IRCode)
    return ir.cfg
end

# For an `IncrementalCompact`, forward to the IR it wraps.
function cfg(compact::IncrementalCompact)
    return cfg(compact.ir)
end
1293+
1294+
# Decide whether the statement `stmt` is invariant with respect to `loop`.
# `Expr` statements are checked with `invariant_expr`; every other statement
# is checked directly with `invariant`.
function invariant_stmt(ir, loop, invariant_stmts, stmt)
    check = stmt isa Expr ? invariant_expr : invariant
    return check(ir, loop, invariant_stmts, stmt)
end
1300+
1301+
# Decide whether a single value `stmt` is invariant with respect to `loop`.
#
# - `Argument`, `GlobalRef`, `QuoteNode` and `Bool` values are always invariant.
# - An `SSAValue` is invariant when its defining block lies outside
#   `loop.blocks`, or when its id is already recorded in `invariant_stmts`.
# - A `Core.MethodInstance` is invariant only when it belongs to a `Method`
#   marked `pure`.
# - Everything else is treated as loop-variant.
function invariant(ir, loop, invariant_stmts, stmt)
    if stmt isa Argument || stmt isa GlobalRef || stmt isa QuoteNode || stmt isa Bool
        return true
    elseif stmt isa SSAValue
        id = stmt.id
        bb = block_for_inst(cfg(ir), id)
        # Defined outside of the loop: cannot change between iterations.
        if !(bb in loop.blocks)
            return true
        end
        # Defined inside the loop: invariant only if already proven so.
        return id in invariant_stmts
    end

    # Check for pure / not side-effecting,
    # since we hoist into pre-header throwing is okay.
    if stmt isa Core.MethodInstance
        def = stmt.def
        # `def` is not necessarily a `Method` (it can be a `Module` for
        # top-level code), and only a `Method` carries the `pure` flag.
        return def isa Method && def.pure
    end
    return false
end
1320+
1321+
# Decide whether the expression `stmt` is invariant with respect to `loop`:
# it is when every operand yielded by `userefs(stmt)` passes `invariant`.
# Every operand is examined; the scan does not stop at the first failure.
function invariant_expr(ir, loop, invariant_stmts, stmt)
    all_invariant = true
    for operand in userefs(stmt)
        if !invariant(ir, loop, invariant_stmts, operand[])
            all_invariant = false
        end
    end
    return all_invariant
end

0 commit comments

Comments
 (0)