diff --git a/llvm/include/llvm/SandboxIR/Context.h b/llvm/include/llvm/SandboxIR/Context.h index 1285598a1c0282..f2056de87cb946 100644 --- a/llvm/include/llvm/SandboxIR/Context.h +++ b/llvm/include/llvm/SandboxIR/Context.h @@ -9,18 +9,39 @@ #ifndef LLVM_SANDBOXIR_CONTEXT_H #define LLVM_SANDBOXIR_CONTEXT_H +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/MapVector.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/IR/LLVMContext.h" #include "llvm/SandboxIR/Tracker.h" #include "llvm/SandboxIR/Type.h" +#include + namespace llvm::sandboxir { -class Module; -class Value; class Argument; +class BBIterator; class Constant; +class Module; +class Value; class Context { +public: + // A EraseInstrCallback receives the instruction about to be erased. + using EraseInstrCallback = std::function; + // A CreateInstrCallback receives the instruction about to be created. + using CreateInstrCallback = std::function; + // A MoveInstrCallback receives the instruction about to be moved, the + // destination BB and an iterator pointing to the insertion position. + using MoveInstrCallback = + std::function; + + /// An ID for a registered callback. Used for deregistration. Using a 64-bit + /// integer so we don't have to worry about the unlikely case of overflowing + /// a 32-bit counter. + using CallbackID = uint64_t; + protected: LLVMContext &LLVMCtx; friend class Type; // For LLVMCtx. @@ -48,6 +69,21 @@ class Context { /// Type objects. DenseMap> LLVMTypeToTypeMap; + /// Callbacks called when an IR instruction is about to get erased. Keys are + /// used as IDs for deregistration. + MapVector EraseInstrCallbacks; + /// Callbacks called when an IR instruction is about to get created. Keys are + /// used as IDs for deregistration. + MapVector CreateInstrCallbacks; + /// Callbacks called when an IR instruction is about to get moved. Keys are + /// used as IDs for deregistration. + MapVector MoveInstrCallbacks; + + /// A counter used for assigning callback IDs during registration. The same + /// counter is used for all kinds of callbacks so we can detect mismatched + /// registration/deregistration. + CallbackID NextCallbackID = 0; + /// Remove \p V from the maps and returns the unique_ptr. std::unique_ptr detachLLVMValue(llvm::Value *V); /// Remove \p SBV from all SandboxIR maps and stop owning it. This effectively @@ -70,6 +106,10 @@ class Context { Constant *getOrCreateConstant(llvm::Constant *LLVMC); friend class Utils; // For getMemoryBase + void runEraseInstrCallbacks(Instruction *I); + void runCreateInstrCallbacks(Instruction *I); + void runMoveInstrCallbacks(Instruction *I, const BBIterator &Where); + // Friends for getOrCreateConstant(). #define DEF_CONST(ID, CLASS) friend class CLASS; #include "llvm/SandboxIR/Values.def" @@ -198,6 +238,28 @@ class Context { /// \Returns the number of values registered with Context. size_t getNumValues() const { return LLVMValueToValueMap.size(); } + + /// Register a callback that gets called when a SandboxIR instruction is about + /// to be removed from its parent. Note that this will also be called when + /// reverting the creation of an instruction. + /// \Returns a callback ID for later deregistration. + CallbackID registerEraseInstrCallback(EraseInstrCallback CB); + void unregisterEraseInstrCallback(CallbackID ID); + + /// Register a callback that gets called right after a SandboxIR instruction + /// is created. Note that this will also be called when reverting the removal + /// of an instruction. + /// \Returns a callback ID for later deregistration. + CallbackID registerCreateInstrCallback(CreateInstrCallback CB); + void unregisterCreateInstrCallback(CallbackID ID); + + /// Register a callback that gets called when a SandboxIR instruction is about + /// to be moved. Note that this will also be called when reverting a move. + /// \Returns a callback ID for later deregistration. + CallbackID registerMoveInstrCallback(MoveInstrCallback CB); + void unregisterMoveInstrCallback(CallbackID ID); + + // TODO: Add callbacks for instructions inserted/removed if needed. }; } // namespace llvm::sandboxir diff --git a/llvm/lib/SandboxIR/Context.cpp b/llvm/lib/SandboxIR/Context.cpp index 486e935bc35fba..5e5cbbbc4515d2 100644 --- a/llvm/lib/SandboxIR/Context.cpp +++ b/llvm/lib/SandboxIR/Context.cpp @@ -35,17 +35,20 @@ Value *Context::registerValue(std::unique_ptr &&VPtr) { assert(VPtr->getSubclassID() != Value::ClassID::User && "Can't register a user!"); + Value *V = VPtr.get(); + [[maybe_unused]] auto Pair = + LLVMValueToValueMap.insert({VPtr->Val, std::move(VPtr)}); + assert(Pair.second && "Already exists!"); + // Track creation of instructions. // Please note that we don't allow the creation of detached instructions, // meaning that the instructions need to be inserted into a block upon // creation. This is why the tracker class combines creation and insertion. - if (auto *I = dyn_cast(VPtr.get())) + if (auto *I = dyn_cast(V)) { getTracker().emplaceIfTracking(I); + runCreateInstrCallbacks(I); + } - Value *V = VPtr.get(); - [[maybe_unused]] auto Pair = - LLVMValueToValueMap.insert({VPtr->Val, std::move(VPtr)}); - assert(Pair.second && "Already exists!"); return V; } @@ -660,4 +663,64 @@ Module *Context::createModule(llvm::Module *LLVMM) { return M; } +void Context::runEraseInstrCallbacks(Instruction *I) { + for (const auto &CBEntry : EraseInstrCallbacks) + CBEntry.second(I); +} + +void Context::runCreateInstrCallbacks(Instruction *I) { + for (auto &CBEntry : CreateInstrCallbacks) + CBEntry.second(I); +} + +void Context::runMoveInstrCallbacks(Instruction *I, const BBIterator &WhereIt) { + for (auto &CBEntry : MoveInstrCallbacks) + CBEntry.second(I, WhereIt); +} + +// An arbitrary limit, to check for accidental misuse. We expect a small number +// of callbacks to be registered at a time, but we can increase this number if +// we discover we needed more. +static constexpr int MaxRegisteredCallbacks = 16; + +Context::CallbackID Context::registerEraseInstrCallback(EraseInstrCallback CB) { + assert(EraseInstrCallbacks.size() <= MaxRegisteredCallbacks && + "EraseInstrCallbacks size limit exceeded"); + CallbackID ID = NextCallbackID++; + EraseInstrCallbacks[ID] = CB; + return ID; +} +void Context::unregisterEraseInstrCallback(CallbackID ID) { + [[maybe_unused]] bool Erased = EraseInstrCallbacks.erase(ID); + assert(Erased && + "Callback ID not found in EraseInstrCallbacks during deregistration"); +} + +Context::CallbackID +Context::registerCreateInstrCallback(CreateInstrCallback CB) { + assert(CreateInstrCallbacks.size() <= MaxRegisteredCallbacks && + "CreateInstrCallbacks size limit exceeded"); + CallbackID ID = NextCallbackID++; + CreateInstrCallbacks[ID] = CB; + return ID; +} +void Context::unregisterCreateInstrCallback(CallbackID ID) { + [[maybe_unused]] bool Erased = CreateInstrCallbacks.erase(ID); + assert(Erased && + "Callback ID not found in CreateInstrCallbacks during deregistration"); +} + +Context::CallbackID Context::registerMoveInstrCallback(MoveInstrCallback CB) { + assert(MoveInstrCallbacks.size() <= MaxRegisteredCallbacks && + "MoveInstrCallbacks size limit exceeded"); + CallbackID ID = NextCallbackID++; + MoveInstrCallbacks[ID] = CB; + return ID; +} +void Context::unregisterMoveInstrCallback(CallbackID ID) { + [[maybe_unused]] bool Erased = MoveInstrCallbacks.erase(ID); + assert(Erased && + "Callback ID not found in MoveInstrCallbacks during deregistration"); +} + } // namespace llvm::sandboxir diff --git a/llvm/lib/SandboxIR/Instruction.cpp b/llvm/lib/SandboxIR/Instruction.cpp index d80d10370e32d8..096b827541eeaf 100644 --- a/llvm/lib/SandboxIR/Instruction.cpp +++ b/llvm/lib/SandboxIR/Instruction.cpp @@ -73,6 +73,8 @@ void Instruction::removeFromParent() { void Instruction::eraseFromParent() { assert(users().empty() && "Still connected to users, can't erase!"); + + Ctx.runEraseInstrCallbacks(this); std::unique_ptr Detached = Ctx.detach(this); auto LLVMInstrs = getLLVMInstrs(); @@ -100,6 +102,7 @@ void Instruction::moveBefore(BasicBlock &BB, const BBIterator &WhereIt) { // Destination is same as origin, nothing to do. return; + Ctx.runMoveInstrCallbacks(this, WhereIt); Ctx.getTracker().emplaceIfTracking(this); auto *LLVMBB = cast(BB.Val); diff --git a/llvm/unittests/SandboxIR/SandboxIRTest.cpp b/llvm/unittests/SandboxIR/SandboxIRTest.cpp index 97113b303f72e5..99e14292a91b92 100644 --- a/llvm/unittests/SandboxIR/SandboxIRTest.cpp +++ b/llvm/unittests/SandboxIR/SandboxIRTest.cpp @@ -22,6 +22,7 @@ #include "llvm/SandboxIR/Value.h" #include "llvm/Support/SourceMgr.h" #include "gmock/gmock-matchers.h" +#include "gmock/gmock-more-matchers.h" #include "gtest/gtest.h" using namespace llvm; @@ -5962,3 +5963,100 @@ TEST_F(SandboxIRTest, CheckClassof) { EXPECT_NE(&sandboxir::CLASS::classof, &sandboxir::Instruction::classof); #include "llvm/SandboxIR/Values.def" } + +TEST_F(SandboxIRTest, InstructionCallbacks) { + parseIR(C, R"IR( + define void @foo(ptr %ptr, i8 %val) { + ret void + } + )IR"); + Function &LLVMF = *M->getFunction("foo"); + sandboxir::Context Ctx(C); + + auto &F = *Ctx.createFunction(&LLVMF); + auto &BB = *F.begin(); + sandboxir::Argument *Ptr = F.getArg(0); + sandboxir::Argument *Val = F.getArg(1); + sandboxir::Instruction *Ret = &BB.front(); + + SmallVector Inserted; + auto InsertCbId = Ctx.registerCreateInstrCallback( + [&Inserted](sandboxir::Instruction *I) { Inserted.push_back(I); }); + + SmallVector Removed; + auto RemoveCbId = Ctx.registerEraseInstrCallback( + [&Removed](sandboxir::Instruction *I) { Removed.push_back(I); }); + + // Keep the moved instruction and the instruction pointed by the Where + // iterator so we can check both callback arguments work as expected. + SmallVector> + Moved; + auto MoveCbId = Ctx.registerMoveInstrCallback( + [&Moved](sandboxir::Instruction *I, const sandboxir::BBIterator &Where) { + // Use a nullptr to signal "move to end" to keep it single. We only + // have a basic block in this test case anyway. + if (Where == Where.getNodeParent()->end()) + Moved.push_back(std::make_pair(I, nullptr)); + else + Moved.push_back(std::make_pair(I, &*Where)); + }); + + // Two more insertion callbacks, to check that they're called in registration + // order. + SmallVector Order; + auto CheckOrderInsertCbId1 = Ctx.registerCreateInstrCallback( + [&Order](sandboxir::Instruction *I) { Order.push_back(1); }); + + auto CheckOrderInsertCbId2 = Ctx.registerCreateInstrCallback( + [&Order](sandboxir::Instruction *I) { Order.push_back(2); }); + + Ctx.save(); + auto *NewI = sandboxir::StoreInst::create(Val, Ptr, /*Align=*/std::nullopt, + Ret->getIterator(), Ctx); + EXPECT_THAT(Inserted, testing::ElementsAre(NewI)); + EXPECT_THAT(Removed, testing::IsEmpty()); + EXPECT_THAT(Moved, testing::IsEmpty()); + EXPECT_THAT(Order, testing::ElementsAre(1, 2)); + + Ret->moveBefore(NewI); + EXPECT_THAT(Inserted, testing::ElementsAre(NewI)); + EXPECT_THAT(Removed, testing::IsEmpty()); + EXPECT_THAT(Moved, testing::ElementsAre(std::make_pair(Ret, NewI))); + + Ret->eraseFromParent(); + EXPECT_THAT(Inserted, testing::ElementsAre(NewI)); + EXPECT_THAT(Removed, testing::ElementsAre(Ret)); + EXPECT_THAT(Moved, testing::ElementsAre(std::make_pair(Ret, NewI))); + + NewI->eraseFromParent(); + EXPECT_THAT(Inserted, testing::ElementsAre(NewI)); + EXPECT_THAT(Removed, testing::ElementsAre(Ret, NewI)); + EXPECT_THAT(Moved, testing::ElementsAre(std::make_pair(Ret, NewI))); + + // Check that after revert the callbacks have been called for the inverse + // operations of the changes made so far. + Ctx.revert(); + EXPECT_THAT(Inserted, testing::ElementsAre(NewI, NewI, Ret)); + EXPECT_THAT(Removed, testing::ElementsAre(Ret, NewI, NewI)); + EXPECT_THAT(Moved, testing::ElementsAre(std::make_pair(Ret, NewI), + std::make_pair(Ret, nullptr))); + EXPECT_THAT(Order, testing::ElementsAre(1, 2, 1, 2, 1, 2)); + + // Check that deregistration works. Do an operation of each type after + // deregistering callbacks and check. + Inserted.clear(); + Removed.clear(); + Moved.clear(); + Ctx.unregisterCreateInstrCallback(InsertCbId); + Ctx.unregisterEraseInstrCallback(RemoveCbId); + Ctx.unregisterMoveInstrCallback(MoveCbId); + Ctx.unregisterCreateInstrCallback(CheckOrderInsertCbId1); + Ctx.unregisterCreateInstrCallback(CheckOrderInsertCbId2); + auto *NewI2 = sandboxir::StoreInst::create(Val, Ptr, /*Align=*/std::nullopt, + Ret->getIterator(), Ctx); + Ret->moveBefore(NewI2); + Ret->eraseFromParent(); + EXPECT_THAT(Inserted, testing::IsEmpty()); + EXPECT_THAT(Removed, testing::IsEmpty()); + EXPECT_THAT(Moved, testing::IsEmpty()); +}