19
19
#include " llvm/ADT/PostOrderIterator.h"
20
20
#include " llvm/IR/Function.h"
21
21
#include " llvm/IR/Instructions.h"
22
- #include " llvm/IR/IntrinsicInst.h"
23
22
#include " llvm/IR/Module.h"
24
23
#include " llvm/Support/Debug.h"
25
24
@@ -44,50 +43,76 @@ void VisitorTemplate::setStrategy(VisitorStrategy strategy) {
44
43
m_strategy = strategy;
45
44
}
46
45
47
- void VisitorTemplate::add (VisitorKey key, VisitorCallback *fn,
48
- VisitorCallbackData data,
49
- VisitorHandler::Projection projection) {
50
- VisitorHandler handler;
51
- handler.callback = fn;
52
- handler.data = data;
53
- handler.projection = projection;
54
-
55
- m_handlers.emplace_back (handler);
46
+ void VisitorTemplate::storeHandlersInOpMap (
47
+ const VisitorKey &key, unsigned handlerIdx,
48
+ VisitorCallbackType visitorCallbackTy) {
49
+ const auto HandlerList =
50
+ [&](const OpDescription &opDescription) -> llvm::SmallVector<unsigned > & {
51
+ if (visitorCallbackTy == VisitorCallbackType::PreVisit)
52
+ return m_opMap[opDescription].PreVisitHandlers ;
56
53
57
- const unsigned handlerIdx = m_handlers.size () - 1 ;
54
+ return m_opMap[opDescription].VisitHandlers ;
55
+ };
58
56
59
57
if (key.m_kind == VisitorKey::Kind::Intrinsic) {
60
- m_opMap[ OpDescription::fromIntrinsic (key.m_intrinsicId )]. push_back (
61
- handlerIdx);
58
+ HandlerList ( OpDescription::fromIntrinsic (key.m_intrinsicId ))
59
+ . push_back ( handlerIdx);
62
60
} else if (key.m_kind == VisitorKey::Kind::OpDescription) {
63
61
const OpDescription *opDesc = key.m_description ;
64
62
65
63
if (opDesc->isCoreOp ()) {
66
64
for (const unsigned op : opDesc->getOpcodes ())
67
- m_opMap[ OpDescription::fromCoreOp (op)] .push_back (handlerIdx);
65
+ HandlerList ( OpDescription::fromCoreOp (op)) .push_back (handlerIdx);
68
66
} else if (opDesc->isIntrinsic ()) {
69
67
for (const unsigned op : opDesc->getOpcodes ())
70
- m_opMap[ OpDescription::fromIntrinsic (op)] .push_back (handlerIdx);
68
+ HandlerList ( OpDescription::fromIntrinsic (op)) .push_back (handlerIdx);
71
69
} else {
72
- m_opMap[ *opDesc] .push_back (handlerIdx);
70
+ HandlerList ( *opDesc) .push_back (handlerIdx);
73
71
}
74
72
} else if (key.m_kind == VisitorKey::Kind::OpSet) {
75
73
const OpSet *opSet = key.m_set ;
76
74
75
+ if (visitorCallbackTy == VisitorCallbackType::PreVisit && opSet->empty ()) {
76
+ // This adds a handler for every stored op.
77
+ // Note: should be used with caution.
78
+ for (auto it : m_opMap)
79
+ it.second .PreVisitHandlers .push_back (handlerIdx);
80
+
81
+ return ;
82
+ }
83
+
77
84
for (unsigned opcode : opSet->getCoreOpcodes ())
78
- m_opMap[ OpDescription::fromCoreOp (opcode)] .push_back (handlerIdx);
85
+ HandlerList ( OpDescription::fromCoreOp (opcode)) .push_back (handlerIdx);
79
86
80
87
for (unsigned intrinsicID : opSet->getIntrinsicIDs ())
81
- m_opMap[OpDescription::fromIntrinsic (intrinsicID)].push_back (handlerIdx);
88
+ HandlerList (OpDescription::fromIntrinsic (intrinsicID))
89
+ .push_back (handlerIdx);
82
90
83
- for (const auto &dialectOpPair : opSet->getDialectOps ()) {
84
- m_opMap[ OpDescription::fromDialectOp (dialectOpPair.isOverload ,
85
- dialectOpPair.mnemonic )]
91
+ for (const auto &dialectOpPair : opSet->getDialectOps ())
92
+ HandlerList ( OpDescription::fromDialectOp (dialectOpPair.isOverload ,
93
+ dialectOpPair.mnemonic ))
86
94
.push_back (handlerIdx);
87
- }
88
95
}
89
96
}
90
97
98
+ void VisitorTemplate::add (VisitorKey key, VisitorCallback *fn,
99
+ VisitorCallbackData data,
100
+ VisitorHandler::Projection projection,
101
+ VisitorCallbackType visitorCallbackTy) {
102
+ assert (visitorCallbackTy != VisitorCallbackType::PreVisit || key.m_set );
103
+
104
+ VisitorHandler handler;
105
+ handler.callback = fn;
106
+ handler.data = data;
107
+ handler.projection = projection;
108
+
109
+ m_handlers.emplace_back (handler);
110
+
111
+ const unsigned handlerIdx = m_handlers.size () - 1 ;
112
+
113
+ storeHandlersInOpMap (key, handlerIdx, visitorCallbackTy);
114
+ }
115
+
91
116
VisitorBuilderBase::VisitorBuilderBase () : m_template(&m_ownedTemplate) {}
92
117
93
118
VisitorBuilderBase::VisitorBuilderBase (VisitorBuilderBase *parent,
@@ -144,6 +169,13 @@ void VisitorBuilderBase::setStrategy(VisitorStrategy strategy) {
144
169
m_template->setStrategy (strategy);
145
170
}
146
171
172
+ void VisitorBuilderBase::addPreVisitCallback (VisitorKey key,
173
+ VisitorCallback *fn,
174
+ VisitorCallbackData data) {
175
+ m_template->add (key, fn, data, m_projection,
176
+ VisitorTemplate::VisitorCallbackType::PreVisit);
177
+ }
178
+
147
179
void VisitorBuilderBase::add (VisitorKey key, VisitorCallback *fn,
148
180
VisitorCallbackData data) {
149
181
m_template->add (key, fn, data, m_projection);
@@ -192,9 +224,12 @@ VisitorBase::VisitorBase(VisitorTemplate &&templ)
192
224
BuildHelper helper (*this , templ.m_handlers );
193
225
194
226
m_opMap.reserve (templ.m_opMap );
195
-
196
- for (auto it : templ.m_opMap )
197
- m_opMap[it.first ] = helper.mapHandlers (it.second );
227
+ for (auto it : templ.m_opMap ) {
228
+ m_opMap[it.first ].PreVisitCallbacks =
229
+ helper.mapHandlers (it.second .PreVisitHandlers );
230
+ m_opMap[it.first ].VisitCallbacks =
231
+ helper.mapHandlers (it.second .VisitHandlers );
232
+ }
198
233
}
199
234
200
235
void VisitorBase::call (HandlerRange handlers, void *payload,
@@ -223,11 +258,14 @@ VisitorResult VisitorBase::call(const VisitorHandler &handler, void *payload,
223
258
}
224
259
225
260
void VisitorBase::visit (void *payload, Instruction &inst) const {
226
- auto handlers = m_opMap.find (inst);
227
- if (!handlers )
261
+ auto mappedHandlers = m_opMap.find (inst);
262
+ if (!mappedHandlers )
228
263
return ;
229
264
230
- call (*handlers.val (), payload, inst);
265
+ auto &callbacks = *mappedHandlers.val ();
266
+
267
+ call (callbacks.PreVisitCallbacks , payload, inst);
268
+ call (callbacks.VisitCallbacks , payload, inst);
231
269
}
232
270
233
271
template <typename FilterT>
@@ -241,19 +279,23 @@ void VisitorBase::visitByDeclarations(void *payload, llvm::Module &module,
241
279
242
280
LLVM_DEBUG (dbgs () << " visit " << decl.getName () << ' \n ' );
243
281
244
- auto handlers = m_opMap.find (decl);
245
- if (!handlers ) {
282
+ auto mappedHandlers = m_opMap.find (decl);
283
+ if (!mappedHandlers ) {
246
284
// Neither a matched intrinsic nor a matched dialect op; skip.
247
285
continue ;
248
286
}
249
287
288
+ auto &callbacks = *mappedHandlers.val ();
289
+
250
290
for (Use &use : make_early_inc_range (decl.uses ())) {
251
291
if (auto *inst = dyn_cast<Instruction>(use.getUser ())) {
252
292
if (!filter (*inst))
253
293
continue ;
254
294
if (auto *callInst = dyn_cast<CallInst>(inst)) {
255
- if (&use == &callInst->getCalledOperandUse ())
256
- call (*handlers.val (), payload, *callInst);
295
+ if (&use == &callInst->getCalledOperandUse ()) {
296
+ call (callbacks.PreVisitCallbacks , payload, *callInst);
297
+ call (callbacks.VisitCallbacks , payload, *callInst);
298
+ }
257
299
}
258
300
}
259
301
}
0 commit comments