Skip to content

Commit

Permalink
[SandboxIR] Add callbacks for instruction insert/remove/move ops (llv…
Browse files Browse the repository at this point in the history
  • Loading branch information
slackito authored Oct 29, 2024
1 parent a9c417c commit 4df71ab
Show file tree
Hide file tree
Showing 4 changed files with 233 additions and 7 deletions.
66 changes: 64 additions & 2 deletions llvm/include/llvm/SandboxIR/Context.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,39 @@
#ifndef LLVM_SANDBOXIR_CONTEXT_H
#define LLVM_SANDBOXIR_CONTEXT_H

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/SandboxIR/Tracker.h"
#include "llvm/SandboxIR/Type.h"

#include <cstdint>

namespace llvm::sandboxir {

class Module;
class Value;
class Argument;
class BBIterator;
class Constant;
class Module;
class Value;

class Context {
public:
// A EraseInstrCallback receives the instruction about to be erased.
using EraseInstrCallback = std::function<void(Instruction *)>;
// A CreateInstrCallback receives the instruction about to be created.
using CreateInstrCallback = std::function<void(Instruction *)>;
// A MoveInstrCallback receives the instruction about to be moved, the
// destination BB and an iterator pointing to the insertion position.
using MoveInstrCallback =
std::function<void(Instruction *, const BBIterator &)>;

/// An ID for a registered callback. Used for deregistration. Using a 64-bit
/// integer so we don't have to worry about the unlikely case of overflowing
/// a 32-bit counter.
using CallbackID = uint64_t;

protected:
LLVMContext &LLVMCtx;
friend class Type; // For LLVMCtx.
Expand Down Expand Up @@ -48,6 +69,21 @@ class Context {
/// Type objects.
DenseMap<llvm::Type *, std::unique_ptr<Type, TypeDeleter>> LLVMTypeToTypeMap;

/// Callbacks called when an IR instruction is about to get erased. Keys are
/// used as IDs for deregistration.
MapVector<CallbackID, EraseInstrCallback> EraseInstrCallbacks;
/// Callbacks called when an IR instruction is about to get created. Keys are
/// used as IDs for deregistration.
MapVector<CallbackID, CreateInstrCallback> CreateInstrCallbacks;
/// Callbacks called when an IR instruction is about to get moved. Keys are
/// used as IDs for deregistration.
MapVector<CallbackID, MoveInstrCallback> MoveInstrCallbacks;

/// A counter used for assigning callback IDs during registration. The same
/// counter is used for all kinds of callbacks so we can detect mismatched
/// registration/deregistration.
CallbackID NextCallbackID = 0;

/// Remove \p V from the maps and returns the unique_ptr.
std::unique_ptr<Value> detachLLVMValue(llvm::Value *V);
/// Remove \p SBV from all SandboxIR maps and stop owning it. This effectively
Expand All @@ -70,6 +106,10 @@ class Context {
Constant *getOrCreateConstant(llvm::Constant *LLVMC);
friend class Utils; // For getMemoryBase

void runEraseInstrCallbacks(Instruction *I);
void runCreateInstrCallbacks(Instruction *I);
void runMoveInstrCallbacks(Instruction *I, const BBIterator &Where);

// Friends for getOrCreateConstant().
#define DEF_CONST(ID, CLASS) friend class CLASS;
#include "llvm/SandboxIR/Values.def"
Expand Down Expand Up @@ -198,6 +238,28 @@ class Context {

/// \Returns the number of values registered with Context.
size_t getNumValues() const { return LLVMValueToValueMap.size(); }

/// Register a callback that gets called when a SandboxIR instruction is about
/// to be removed from its parent. Note that this will also be called when
/// reverting the creation of an instruction.
/// \Returns a callback ID for later deregistration.
CallbackID registerEraseInstrCallback(EraseInstrCallback CB);
void unregisterEraseInstrCallback(CallbackID ID);

/// Register a callback that gets called right after a SandboxIR instruction
/// is created. Note that this will also be called when reverting the removal
/// of an instruction.
/// \Returns a callback ID for later deregistration.
CallbackID registerCreateInstrCallback(CreateInstrCallback CB);
void unregisterCreateInstrCallback(CallbackID ID);

/// Register a callback that gets called when a SandboxIR instruction is about
/// to be moved. Note that this will also be called when reverting a move.
/// \Returns a callback ID for later deregistration.
CallbackID registerMoveInstrCallback(MoveInstrCallback CB);
void unregisterMoveInstrCallback(CallbackID ID);

// TODO: Add callbacks for instructions inserted/removed if needed.
};

} // namespace llvm::sandboxir
Expand Down
73 changes: 68 additions & 5 deletions llvm/lib/SandboxIR/Context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,17 +35,20 @@ Value *Context::registerValue(std::unique_ptr<Value> &&VPtr) {
assert(VPtr->getSubclassID() != Value::ClassID::User &&
"Can't register a user!");

Value *V = VPtr.get();
[[maybe_unused]] auto Pair =
LLVMValueToValueMap.insert({VPtr->Val, std::move(VPtr)});
assert(Pair.second && "Already exists!");

// Track creation of instructions.
// Please note that we don't allow the creation of detached instructions,
// meaning that the instructions need to be inserted into a block upon
// creation. This is why the tracker class combines creation and insertion.
if (auto *I = dyn_cast<Instruction>(VPtr.get()))
if (auto *I = dyn_cast<Instruction>(V)) {
getTracker().emplaceIfTracking<CreateAndInsertInst>(I);
runCreateInstrCallbacks(I);
}

Value *V = VPtr.get();
[[maybe_unused]] auto Pair =
LLVMValueToValueMap.insert({VPtr->Val, std::move(VPtr)});
assert(Pair.second && "Already exists!");
return V;
}

Expand Down Expand Up @@ -660,4 +663,64 @@ Module *Context::createModule(llvm::Module *LLVMM) {
return M;
}

void Context::runEraseInstrCallbacks(Instruction *I) {
for (const auto &CBEntry : EraseInstrCallbacks)
CBEntry.second(I);
}

void Context::runCreateInstrCallbacks(Instruction *I) {
for (auto &CBEntry : CreateInstrCallbacks)
CBEntry.second(I);
}

void Context::runMoveInstrCallbacks(Instruction *I, const BBIterator &WhereIt) {
for (auto &CBEntry : MoveInstrCallbacks)
CBEntry.second(I, WhereIt);
}

// An arbitrary limit, to check for accidental misuse. We expect a small number
// of callbacks to be registered at a time, but we can increase this number if
// we discover we needed more.
static constexpr int MaxRegisteredCallbacks = 16;

Context::CallbackID Context::registerEraseInstrCallback(EraseInstrCallback CB) {
assert(EraseInstrCallbacks.size() <= MaxRegisteredCallbacks &&
"EraseInstrCallbacks size limit exceeded");
CallbackID ID = NextCallbackID++;
EraseInstrCallbacks[ID] = CB;
return ID;
}
void Context::unregisterEraseInstrCallback(CallbackID ID) {
[[maybe_unused]] bool Erased = EraseInstrCallbacks.erase(ID);
assert(Erased &&
"Callback ID not found in EraseInstrCallbacks during deregistration");
}

Context::CallbackID
Context::registerCreateInstrCallback(CreateInstrCallback CB) {
assert(CreateInstrCallbacks.size() <= MaxRegisteredCallbacks &&
"CreateInstrCallbacks size limit exceeded");
CallbackID ID = NextCallbackID++;
CreateInstrCallbacks[ID] = CB;
return ID;
}
void Context::unregisterCreateInstrCallback(CallbackID ID) {
[[maybe_unused]] bool Erased = CreateInstrCallbacks.erase(ID);
assert(Erased &&
"Callback ID not found in CreateInstrCallbacks during deregistration");
}

Context::CallbackID Context::registerMoveInstrCallback(MoveInstrCallback CB) {
assert(MoveInstrCallbacks.size() <= MaxRegisteredCallbacks &&
"MoveInstrCallbacks size limit exceeded");
CallbackID ID = NextCallbackID++;
MoveInstrCallbacks[ID] = CB;
return ID;
}
void Context::unregisterMoveInstrCallback(CallbackID ID) {
[[maybe_unused]] bool Erased = MoveInstrCallbacks.erase(ID);
assert(Erased &&
"Callback ID not found in MoveInstrCallbacks during deregistration");
}

} // namespace llvm::sandboxir
3 changes: 3 additions & 0 deletions llvm/lib/SandboxIR/Instruction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ void Instruction::removeFromParent() {

void Instruction::eraseFromParent() {
assert(users().empty() && "Still connected to users, can't erase!");

Ctx.runEraseInstrCallbacks(this);
std::unique_ptr<Value> Detached = Ctx.detach(this);
auto LLVMInstrs = getLLVMInstrs();

Expand Down Expand Up @@ -100,6 +102,7 @@ void Instruction::moveBefore(BasicBlock &BB, const BBIterator &WhereIt) {
// Destination is same as origin, nothing to do.
return;

Ctx.runMoveInstrCallbacks(this, WhereIt);
Ctx.getTracker().emplaceIfTracking<MoveInstr>(this);

auto *LLVMBB = cast<llvm::BasicBlock>(BB.Val);
Expand Down
98 changes: 98 additions & 0 deletions llvm/unittests/SandboxIR/SandboxIRTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "llvm/SandboxIR/Value.h"
#include "llvm/Support/SourceMgr.h"
#include "gmock/gmock-matchers.h"
#include "gmock/gmock-more-matchers.h"
#include "gtest/gtest.h"

using namespace llvm;
Expand Down Expand Up @@ -5962,3 +5963,100 @@ TEST_F(SandboxIRTest, CheckClassof) {
EXPECT_NE(&sandboxir::CLASS::classof, &sandboxir::Instruction::classof);
#include "llvm/SandboxIR/Values.def"
}

TEST_F(SandboxIRTest, InstructionCallbacks) {
parseIR(C, R"IR(
define void @foo(ptr %ptr, i8 %val) {
ret void
}
)IR");
Function &LLVMF = *M->getFunction("foo");
sandboxir::Context Ctx(C);

auto &F = *Ctx.createFunction(&LLVMF);
auto &BB = *F.begin();
sandboxir::Argument *Ptr = F.getArg(0);
sandboxir::Argument *Val = F.getArg(1);
sandboxir::Instruction *Ret = &BB.front();

SmallVector<sandboxir::Instruction *> Inserted;
auto InsertCbId = Ctx.registerCreateInstrCallback(
[&Inserted](sandboxir::Instruction *I) { Inserted.push_back(I); });

SmallVector<sandboxir::Instruction *> Removed;
auto RemoveCbId = Ctx.registerEraseInstrCallback(
[&Removed](sandboxir::Instruction *I) { Removed.push_back(I); });

// Keep the moved instruction and the instruction pointed by the Where
// iterator so we can check both callback arguments work as expected.
SmallVector<std::pair<sandboxir::Instruction *, sandboxir::Instruction *>>
Moved;
auto MoveCbId = Ctx.registerMoveInstrCallback(
[&Moved](sandboxir::Instruction *I, const sandboxir::BBIterator &Where) {
// Use a nullptr to signal "move to end" to keep it single. We only
// have a basic block in this test case anyway.
if (Where == Where.getNodeParent()->end())
Moved.push_back(std::make_pair(I, nullptr));
else
Moved.push_back(std::make_pair(I, &*Where));
});

// Two more insertion callbacks, to check that they're called in registration
// order.
SmallVector<int> Order;
auto CheckOrderInsertCbId1 = Ctx.registerCreateInstrCallback(
[&Order](sandboxir::Instruction *I) { Order.push_back(1); });

auto CheckOrderInsertCbId2 = Ctx.registerCreateInstrCallback(
[&Order](sandboxir::Instruction *I) { Order.push_back(2); });

Ctx.save();
auto *NewI = sandboxir::StoreInst::create(Val, Ptr, /*Align=*/std::nullopt,
Ret->getIterator(), Ctx);
EXPECT_THAT(Inserted, testing::ElementsAre(NewI));
EXPECT_THAT(Removed, testing::IsEmpty());
EXPECT_THAT(Moved, testing::IsEmpty());
EXPECT_THAT(Order, testing::ElementsAre(1, 2));

Ret->moveBefore(NewI);
EXPECT_THAT(Inserted, testing::ElementsAre(NewI));
EXPECT_THAT(Removed, testing::IsEmpty());
EXPECT_THAT(Moved, testing::ElementsAre(std::make_pair(Ret, NewI)));

Ret->eraseFromParent();
EXPECT_THAT(Inserted, testing::ElementsAre(NewI));
EXPECT_THAT(Removed, testing::ElementsAre(Ret));
EXPECT_THAT(Moved, testing::ElementsAre(std::make_pair(Ret, NewI)));

NewI->eraseFromParent();
EXPECT_THAT(Inserted, testing::ElementsAre(NewI));
EXPECT_THAT(Removed, testing::ElementsAre(Ret, NewI));
EXPECT_THAT(Moved, testing::ElementsAre(std::make_pair(Ret, NewI)));

// Check that after revert the callbacks have been called for the inverse
// operations of the changes made so far.
Ctx.revert();
EXPECT_THAT(Inserted, testing::ElementsAre(NewI, NewI, Ret));
EXPECT_THAT(Removed, testing::ElementsAre(Ret, NewI, NewI));
EXPECT_THAT(Moved, testing::ElementsAre(std::make_pair(Ret, NewI),
std::make_pair(Ret, nullptr)));
EXPECT_THAT(Order, testing::ElementsAre(1, 2, 1, 2, 1, 2));

// Check that deregistration works. Do an operation of each type after
// deregistering callbacks and check.
Inserted.clear();
Removed.clear();
Moved.clear();
Ctx.unregisterCreateInstrCallback(InsertCbId);
Ctx.unregisterEraseInstrCallback(RemoveCbId);
Ctx.unregisterMoveInstrCallback(MoveCbId);
Ctx.unregisterCreateInstrCallback(CheckOrderInsertCbId1);
Ctx.unregisterCreateInstrCallback(CheckOrderInsertCbId2);
auto *NewI2 = sandboxir::StoreInst::create(Val, Ptr, /*Align=*/std::nullopt,
Ret->getIterator(), Ctx);
Ret->moveBefore(NewI2);
Ret->eraseFromParent();
EXPECT_THAT(Inserted, testing::IsEmpty());
EXPECT_THAT(Removed, testing::IsEmpty());
EXPECT_THAT(Moved, testing::IsEmpty());
}

0 comments on commit 4df71ab

Please sign in to comment.