Skip to content

Commit d7f2314

Browse files
tsymallatsymalla-AMD
authored andcommitted
Implement support for pre-visitor callbacks.
This change implements a mechanism to run pre-visit callbacks on each operation it stumbles upon. This can be generalized for every op that is being visited, for the current nesting level, or for a given `OpSet`. The code simplifies writing visitors in cases where generic code is written for a set operations, like resetting the insertion point of an `IRBuilder`.
1 parent 3f9e17f commit d7f2314

File tree

5 files changed

+173
-41
lines changed

5 files changed

+173
-41
lines changed

example/ExampleMain.cpp

+39-4
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,7 @@ std::unique_ptr<Module> createModuleExample(LLVMContext &context) {
152152

153153
struct VisitorInnermost {
154154
int counter = 0;
155+
raw_ostream *out = nullptr;
155156
};
156157

157158
struct VisitorNest {
@@ -177,6 +178,13 @@ struct llvm_dialects::VisitorPayloadProjection<VisitorNest, raw_ostream> {
177178
static raw_ostream &project(VisitorNest &nest) { return *nest.out; }
178179
};
179180

181+
template <>
182+
struct llvm_dialects::VisitorPayloadProjection<VisitorInnermost, raw_ostream> {
183+
static raw_ostream &project(VisitorInnermost &innerMost) {
184+
return *innerMost.out;
185+
}
186+
};
187+
180188
LLVM_DIALECTS_VISITOR_PAYLOAD_PROJECT_FIELD(VisitorContainer, nest)
181189
LLVM_DIALECTS_VISITOR_PAYLOAD_PROJECT_FIELD(VisitorNest, inner)
182190

@@ -215,8 +223,8 @@ template <bool rpot> const Visitor<VisitorContainer> &getExampleVisitor() {
215223
b.addSet(complexSet, [](VisitorNest &self, llvm::Instruction &op) {
216224
assert((op.getOpcode() == Instruction::Ret ||
217225
(isa<IntrinsicInst>(&op) &&
218-
cast<IntrinsicInst>(&op)->getIntrinsicID() ==
219-
Intrinsic::umin)) &&
226+
cast<IntrinsicInst>(&op)->getIntrinsicID() ==
227+
Intrinsic::umin)) &&
220228
"Unexpected operation detected while visiting OpSet!");
221229

222230
if (op.getOpcode() == Instruction::Ret) {
@@ -249,10 +257,36 @@ template <bool rpot> const Visitor<VisitorContainer> &getExampleVisitor() {
249257
Intrinsic::umax, [](raw_ostream &out, IntrinsicInst &umax) {
250258
out << "visiting umax intrinsic: " << umax << '\n';
251259
});
260+
b.addPreVisitCallback<xd::ReadOp, xd::WriteOp>(
261+
[](raw_ostream &out, llvm::Instruction &inst) {
262+
if (isa<xd::ReadOp>(inst))
263+
out << "Will visit ReadOp next: " << inst << '\n';
264+
else if (isa<xd::WriteOp>(inst))
265+
out << "Will visit WriteOp next: " << inst << '\n';
266+
else
267+
llvm_unreachable("Unexpected op!");
268+
});
269+
270+
b.addPreVisitCallback([](raw_ostream &out, Instruction &inst) {
271+
if (isa<IntrinsicInst>(inst))
272+
out << "Pre-visiting intrinsic instruction: " << inst << '\n';
273+
});
252274
});
253275
b.nest<VisitorInnermost>([](VisitorBuilder<VisitorInnermost> &b) {
254-
b.add<xd::ITruncOp>([](VisitorInnermost &inner,
255-
xd::ITruncOp &op) { inner.counter++; });
276+
b.add<xd::ITruncOp>(
277+
[](VisitorInnermost &inner, xd::ITruncOp &op) {
278+
inner.counter++;
279+
*inner.out
280+
<< "Counter after visiting ITruncOp: " << inner.counter
281+
<< '\n';
282+
});
283+
284+
b.addPreVisitCallback<xd::ITruncOp>(
285+
[](VisitorInnermost &inner, Instruction &op) {
286+
if (isa<xd::ITruncOp>(op))
287+
*inner.out << "Counter before visiting ITruncOp: "
288+
<< inner.counter << '\n';
289+
});
256290
});
257291
})
258292
.setStrategy(rpot ? VisitorStrategy::ReversePostOrder
@@ -267,6 +301,7 @@ void exampleVisit(Module &module) {
267301

268302
VisitorContainer container;
269303
container.nest.out = &outs();
304+
container.nest.inner.out = &outs();
270305
visitor.visit(container, module);
271306

272307
outs() << "inner.counter = " << container.nest.inner.counter << '\n';

include/llvm-dialects/Dialect/OpSet.h

+6-1
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ class OpSet final {
9191
// arguments.
9292
template <typename... OpTs> static const OpSet get() {
9393
static OpSet set;
94-
(... && appendT<OpTs>(set));
94+
(void)(... && appendT<OpTs>(set));
9595
return set;
9696
}
9797

@@ -153,6 +153,11 @@ class OpSet final {
153153
return isMatchingDialectOp(func.getName());
154154
}
155155

156+
bool empty() const {
157+
return m_coreOpcodes.empty() && m_intrinsicIDs.empty() &&
158+
m_dialectOps.empty();
159+
}
160+
156161
// -------------------------------------------------------------
157162
// Convenience getters to access the internal data structures.
158163
// -------------------------------------------------------------

include/llvm-dialects/Dialect/Visitor.h

+47-3
Original file line numberDiff line numberDiff line change
@@ -238,15 +238,27 @@ class VisitorTemplate {
238238
friend class VisitorBuilderBase;
239239

240240
public:
241+
enum class VisitorCallbackType : uint8_t { PreVisit = 0, Visit = 1 };
242+
241243
void setStrategy(VisitorStrategy strategy);
242244
void add(VisitorKey key, VisitorCallback *fn, VisitorCallbackData data,
243-
VisitorHandler::Projection projection);
245+
VisitorHandler::Projection projection,
246+
VisitorCallbackType visitorCallbackTy = VisitorCallbackType::Visit);
244247

245248
private:
249+
void storeHandlersInOpMap(const VisitorKey &key, unsigned handlerIdx,
250+
VisitorCallbackType callbackTy);
251+
246252
VisitorStrategy m_strategy = VisitorStrategy::Default;
247253
std::vector<PayloadProjection> m_projections;
248254
std::vector<VisitorHandler> m_handlers;
249-
OpMap<llvm::SmallVector<unsigned>> m_opMap;
255+
256+
struct Handlers {
257+
llvm::SmallVector<unsigned> PreVisitHandlers;
258+
llvm::SmallVector<unsigned> VisitHandlers;
259+
};
260+
261+
OpMap<Handlers> m_opMap;
250262
};
251263

252264
/// @brief Base class for VisitorBuilders
@@ -279,6 +291,9 @@ class VisitorBuilderBase {
279291

280292
void setStrategy(VisitorStrategy strategy);
281293

294+
void addPreVisitCallback(VisitorKey key, VisitorCallback *fn,
295+
VisitorCallbackData data);
296+
282297
void add(VisitorKey key, VisitorCallback *fn, VisitorCallbackData data);
283298

284299
VisitorBase build();
@@ -307,6 +322,11 @@ class VisitorBase {
307322
class BuildHelper;
308323
using HandlerRange = std::pair<unsigned, unsigned>;
309324

325+
struct MappedHandlers {
326+
HandlerRange PreVisitCallbacks;
327+
HandlerRange VisitCallbacks;
328+
};
329+
310330
void call(HandlerRange handlers, void *payload,
311331
llvm::Instruction &inst) const;
312332
VisitorResult call(const VisitorHandler &handler, void *payload,
@@ -319,7 +339,7 @@ class VisitorBase {
319339
VisitorStrategy m_strategy;
320340
std::vector<PayloadProjection> m_projections;
321341
std::vector<VisitorHandler> m_handlers;
322-
OpMap<HandlerRange> m_opMap;
342+
OpMap<MappedHandlers> m_opMap;
323343
};
324344

325345
} // namespace detail
@@ -386,6 +406,20 @@ class VisitorBuilder : private detail::VisitorBuilderBase {
386406
return *this;
387407
}
388408

409+
VisitorBuilder &
410+
addPreVisitCallback(const OpSet &opSet,
411+
VisitorResult (*fn)(PayloadT &, llvm::Instruction &I)) {
412+
addPreVisitCase(detail::VisitorKey::opSet(opSet), fn);
413+
return *this;
414+
}
415+
416+
template <typename... OpTs>
417+
VisitorBuilder &addPreVisitCallback(void (*fn)(PayloadT &,
418+
llvm::Instruction &I)) {
419+
addPreVisitCase(detail::VisitorKey::opSet<OpTs...>(), fn);
420+
return *this;
421+
}
422+
389423
Visitor<PayloadT> build() { return VisitorBuilderBase::build(); }
390424

391425
template <typename OpT>
@@ -510,6 +544,16 @@ class VisitorBuilder : private detail::VisitorBuilderBase {
510544
VisitorBuilderBase::add(key, &VisitorBuilder::setForwarder<ReturnT>, data);
511545
}
512546

547+
template <typename ReturnT>
548+
void addPreVisitCase(detail::VisitorKey key,
549+
ReturnT (*fn)(PayloadT &, llvm::Instruction &)) {
550+
detail::VisitorCallbackData data{};
551+
static_assert(sizeof(fn) <= sizeof(data.data));
552+
memcpy(&data.data, &fn, sizeof(fn));
553+
VisitorBuilderBase::addPreVisitCallback(
554+
key, &VisitorBuilder::setForwarder<ReturnT>, data);
555+
}
556+
513557
template <typename OpT, typename ReturnT>
514558
void addMemberFnCase(detail::VisitorKey key, ReturnT (PayloadT::*fn)(OpT &)) {
515559
detail::VisitorCallbackData data{};

lib/Dialect/Visitor.cpp

+74-32
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
#include "llvm/ADT/PostOrderIterator.h"
2020
#include "llvm/IR/Function.h"
2121
#include "llvm/IR/Instructions.h"
22-
#include "llvm/IR/IntrinsicInst.h"
2322
#include "llvm/IR/Module.h"
2423
#include "llvm/Support/Debug.h"
2524

@@ -44,50 +43,76 @@ void VisitorTemplate::setStrategy(VisitorStrategy strategy) {
4443
m_strategy = strategy;
4544
}
4645

47-
void VisitorTemplate::add(VisitorKey key, VisitorCallback *fn,
48-
VisitorCallbackData data,
49-
VisitorHandler::Projection projection) {
50-
VisitorHandler handler;
51-
handler.callback = fn;
52-
handler.data = data;
53-
handler.projection = projection;
54-
55-
m_handlers.emplace_back(handler);
46+
void VisitorTemplate::storeHandlersInOpMap(
47+
const VisitorKey &key, unsigned handlerIdx,
48+
VisitorCallbackType visitorCallbackTy) {
49+
const auto HandlerList =
50+
[&](const OpDescription &opDescription) -> llvm::SmallVector<unsigned> & {
51+
if (visitorCallbackTy == VisitorCallbackType::PreVisit)
52+
return m_opMap[opDescription].PreVisitHandlers;
5653

57-
const unsigned handlerIdx = m_handlers.size() - 1;
54+
return m_opMap[opDescription].VisitHandlers;
55+
};
5856

5957
if (key.m_kind == VisitorKey::Kind::Intrinsic) {
60-
m_opMap[OpDescription::fromIntrinsic(key.m_intrinsicId)].push_back(
61-
handlerIdx);
58+
HandlerList(OpDescription::fromIntrinsic(key.m_intrinsicId))
59+
.push_back(handlerIdx);
6260
} else if (key.m_kind == VisitorKey::Kind::OpDescription) {
6361
const OpDescription *opDesc = key.m_description;
6462

6563
if (opDesc->isCoreOp()) {
6664
for (const unsigned op : opDesc->getOpcodes())
67-
m_opMap[OpDescription::fromCoreOp(op)].push_back(handlerIdx);
65+
HandlerList(OpDescription::fromCoreOp(op)).push_back(handlerIdx);
6866
} else if (opDesc->isIntrinsic()) {
6967
for (const unsigned op : opDesc->getOpcodes())
70-
m_opMap[OpDescription::fromIntrinsic(op)].push_back(handlerIdx);
68+
HandlerList(OpDescription::fromIntrinsic(op)).push_back(handlerIdx);
7169
} else {
72-
m_opMap[*opDesc].push_back(handlerIdx);
70+
HandlerList(*opDesc).push_back(handlerIdx);
7371
}
7472
} else if (key.m_kind == VisitorKey::Kind::OpSet) {
7573
const OpSet *opSet = key.m_set;
7674

75+
if (visitorCallbackTy == VisitorCallbackType::PreVisit && opSet->empty()) {
76+
// This adds a handler for every stored op.
77+
// Note: should be used with caution.
78+
for (auto it : m_opMap)
79+
it.second.PreVisitHandlers.push_back(handlerIdx);
80+
81+
return;
82+
}
83+
7784
for (unsigned opcode : opSet->getCoreOpcodes())
78-
m_opMap[OpDescription::fromCoreOp(opcode)].push_back(handlerIdx);
85+
HandlerList(OpDescription::fromCoreOp(opcode)).push_back(handlerIdx);
7986

8087
for (unsigned intrinsicID : opSet->getIntrinsicIDs())
81-
m_opMap[OpDescription::fromIntrinsic(intrinsicID)].push_back(handlerIdx);
88+
HandlerList(OpDescription::fromIntrinsic(intrinsicID))
89+
.push_back(handlerIdx);
8290

83-
for (const auto &dialectOpPair : opSet->getDialectOps()) {
84-
m_opMap[OpDescription::fromDialectOp(dialectOpPair.isOverload,
85-
dialectOpPair.mnemonic)]
91+
for (const auto &dialectOpPair : opSet->getDialectOps())
92+
HandlerList(OpDescription::fromDialectOp(dialectOpPair.isOverload,
93+
dialectOpPair.mnemonic))
8694
.push_back(handlerIdx);
87-
}
8895
}
8996
}
9097

98+
void VisitorTemplate::add(VisitorKey key, VisitorCallback *fn,
99+
VisitorCallbackData data,
100+
VisitorHandler::Projection projection,
101+
VisitorCallbackType visitorCallbackTy) {
102+
assert(visitorCallbackTy != VisitorCallbackType::PreVisit || key.m_set);
103+
104+
VisitorHandler handler;
105+
handler.callback = fn;
106+
handler.data = data;
107+
handler.projection = projection;
108+
109+
m_handlers.emplace_back(handler);
110+
111+
const unsigned handlerIdx = m_handlers.size() - 1;
112+
113+
storeHandlersInOpMap(key, handlerIdx, visitorCallbackTy);
114+
}
115+
91116
VisitorBuilderBase::VisitorBuilderBase() : m_template(&m_ownedTemplate) {}
92117

93118
VisitorBuilderBase::VisitorBuilderBase(VisitorBuilderBase *parent,
@@ -144,6 +169,13 @@ void VisitorBuilderBase::setStrategy(VisitorStrategy strategy) {
144169
m_template->setStrategy(strategy);
145170
}
146171

172+
void VisitorBuilderBase::addPreVisitCallback(VisitorKey key,
173+
VisitorCallback *fn,
174+
VisitorCallbackData data) {
175+
m_template->add(key, fn, data, m_projection,
176+
VisitorTemplate::VisitorCallbackType::PreVisit);
177+
}
178+
147179
void VisitorBuilderBase::add(VisitorKey key, VisitorCallback *fn,
148180
VisitorCallbackData data) {
149181
m_template->add(key, fn, data, m_projection);
@@ -192,9 +224,12 @@ VisitorBase::VisitorBase(VisitorTemplate &&templ)
192224
BuildHelper helper(*this, templ.m_handlers);
193225

194226
m_opMap.reserve(templ.m_opMap);
195-
196-
for (auto it : templ.m_opMap)
197-
m_opMap[it.first] = helper.mapHandlers(it.second);
227+
for (auto it : templ.m_opMap) {
228+
m_opMap[it.first].PreVisitCallbacks =
229+
helper.mapHandlers(it.second.PreVisitHandlers);
230+
m_opMap[it.first].VisitCallbacks =
231+
helper.mapHandlers(it.second.VisitHandlers);
232+
}
198233
}
199234

200235
void VisitorBase::call(HandlerRange handlers, void *payload,
@@ -223,11 +258,14 @@ VisitorResult VisitorBase::call(const VisitorHandler &handler, void *payload,
223258
}
224259

225260
void VisitorBase::visit(void *payload, Instruction &inst) const {
226-
auto handlers = m_opMap.find(inst);
227-
if (!handlers)
261+
auto mappedHandlers = m_opMap.find(inst);
262+
if (!mappedHandlers)
228263
return;
229264

230-
call(*handlers.val(), payload, inst);
265+
auto &callbacks = *mappedHandlers.val();
266+
267+
call(callbacks.PreVisitCallbacks, payload, inst);
268+
call(callbacks.VisitCallbacks, payload, inst);
231269
}
232270

233271
template <typename FilterT>
@@ -241,19 +279,23 @@ void VisitorBase::visitByDeclarations(void *payload, llvm::Module &module,
241279

242280
LLVM_DEBUG(dbgs() << "visit " << decl.getName() << '\n');
243281

244-
auto handlers = m_opMap.find(decl);
245-
if (!handlers) {
282+
auto mappedHandlers = m_opMap.find(decl);
283+
if (!mappedHandlers) {
246284
// Neither a matched intrinsic nor a matched dialect op; skip.
247285
continue;
248286
}
249287

288+
auto &callbacks = *mappedHandlers.val();
289+
250290
for (Use &use : make_early_inc_range(decl.uses())) {
251291
if (auto *inst = dyn_cast<Instruction>(use.getUser())) {
252292
if (!filter(*inst))
253293
continue;
254294
if (auto *callInst = dyn_cast<CallInst>(inst)) {
255-
if (&use == &callInst->getCalledOperandUse())
256-
call(*handlers.val(), payload, *callInst);
295+
if (&use == &callInst->getCalledOperandUse()) {
296+
call(callbacks.PreVisitCallbacks, payload, *callInst);
297+
call(callbacks.VisitCallbacks, payload, *callInst);
298+
}
257299
}
258300
}
259301
}

0 commit comments

Comments
 (0)