diff --git a/src/Compiler/CompilerPasses.cpp b/src/Compiler/CompilerPasses.cpp index 231b241c5e..5561955304 100644 --- a/src/Compiler/CompilerPasses.cpp +++ b/src/Compiler/CompilerPasses.cpp @@ -276,6 +276,10 @@ void addKrnlToLLVMPasses( // pm.addNestedPass(krnl::createConvertSeqToMemrefPass()); pm.addPass(mlir::memref::createFoldMemRefAliasOpsPass()); + + if (profileIR) + pm.addNestedPass(onnx_mlir::createInstrumentCleanupPass()); + if (enableBoundCheck) pm.addPass(mlir::createGenerateRuntimeVerificationPass()); pm.addPass(krnl::createConvertKrnlToLLVMPass(verifyInputTensors, diff --git a/src/Pass/Passes.hpp b/src/Pass/Passes.hpp index 3062e0abe8..f22cbf2595 100644 --- a/src/Pass/Passes.hpp +++ b/src/Pass/Passes.hpp @@ -57,6 +57,8 @@ std::unique_ptr createConstPropONNXToONNXPass(); std::unique_ptr createInstrumentPass(); std::unique_ptr createInstrumentPass( const std::string &ops, unsigned actions); +/// Pass for instrument cleanup. +std::unique_ptr createInstrumentCleanupPass(); /// Passes for instrumenting the ONNX ops to print their operand type /// signatures at runtime. diff --git a/src/Tools/onnx-mlir-opt/RegisterPasses.cpp b/src/Tools/onnx-mlir-opt/RegisterPasses.cpp index deb0783bd0..671dea1857 100644 --- a/src/Tools/onnx-mlir-opt/RegisterPasses.cpp +++ b/src/Tools/onnx-mlir-opt/RegisterPasses.cpp @@ -70,6 +70,10 @@ void registerOMPasses(int optLevel) { mlir::registerPass( []() -> std::unique_ptr { return createInstrumentPass(); }); + mlir::registerPass([]() -> std::unique_ptr { + return createInstrumentCleanupPass(); + }); + mlir::registerPass([]() -> std::unique_ptr { return createInstrumentONNXSignaturePass("NONE"); }); diff --git a/src/Transform/CMakeLists.txt b/src/Transform/CMakeLists.txt index cc51752de0..c09d9fb571 100644 --- a/src/Transform/CMakeLists.txt +++ b/src/Transform/CMakeLists.txt @@ -20,6 +20,7 @@ add_onnx_mlir_library(OMScfParallelPrivateRegion add_onnx_mlir_library(OMInstrument InstrumentPass.cpp + InstrumentCleanupPass.cpp INCLUDE_DIRS PUBLIC ${ONNX_MLIR_SRC_ROOT}/include diff --git a/src/Transform/InstrumentCleanupPass.cpp b/src/Transform/InstrumentCleanupPass.cpp new file mode 100644 index 0000000000..8cc8a94609 --- /dev/null +++ b/src/Transform/InstrumentCleanupPass.cpp @@ -0,0 +1,113 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + */ + +//===------- InstrumentCleanupPass.cpp - Instrumentation -----------------===// +// +// Copyright 2025 The IBM Research Authors. +// +// ============================================================================= +// +// This file implements a Function level pass that remove consecutive +// instrumentation operations (first with "before" tag and second with "after") +// as they do not measure anything. +// +//===----------------------------------------------------------------------===// + +#include +#include +#include + +#include "onnx-mlir/Compiler/OMCompilerRuntimeTypes.h" +#include "onnx-mlir/Compiler/OMCompilerTypes.h" + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/Interfaces/CallInterfaces.h" +#include "mlir/Pass/Pass.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/Support/raw_ostream.h" + +#include "src/Compiler/OptionUtils.hpp" +#include "src/Dialect/Krnl/KrnlOps.hpp" +#include "src/Interface/ShapeInferenceOpInterface.hpp" +#include "src/Pass/Passes.hpp" + +using namespace mlir; + +namespace onnx_mlir { + +/*! + * This pass insert KrnlInstrumentOp before and after each ops + */ + +class InstrumentCleanupPass : public mlir::PassWrapper> { + +public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(InstrumentCleanupPass) + + InstrumentCleanupPass(){}; + InstrumentCleanupPass(const InstrumentCleanupPass &pass) + : mlir::PassWrapper>() {} + +private: +public: + StringRef getArgument() const override { return "instrument-cleanup"; } + + StringRef getDescription() const override { + return "instrument cleanup on ops."; + } + + void runOnOperation() override { + llvm::SmallVector eraseOpList; + bool skipNext = false; + + // Iterate on the operations nested in this function + getOperation().walk([&](mlir::Operation *op) -> WalkResult { + if (skipNext) { + skipNext = false; + return WalkResult::advance(); + } + KrnlInstrumentOp firstInstrOp = mlir::dyn_cast(op); + // Check if we have a first instrumentation op with instr before. + if (!firstInstrOp) + return WalkResult::advance(); + uint64_t firstTag = firstInstrOp.getTag(); + // skip if not before, or if this call initializes the instrumentation. + if (!IS_INSTRUMENT_BEFORE_OP(firstTag) || IS_INSTRUMENT_INIT(firstTag)) + return WalkResult::advance(); + // Check if we have a second instrumentation op with instr after. + Operation *nextOp = op->getNextNode(); + if (!nextOp) + return WalkResult::advance(); + KrnlInstrumentOp secondInstrOp = mlir::dyn_cast(nextOp); + if (!secondInstrOp) + return WalkResult::advance(); + uint64_t secondTag = secondInstrOp.getTag(); + // skip if not after, or if this call initializes the instrumentation. + if (!IS_INSTRUMENT_AFTER_OP(secondTag) || IS_INSTRUMENT_INIT(secondTag)) + return WalkResult::advance(); + // Could check opName but we already have a before/after pair, it can only + // be of the same op. + // Schedule both instrumentation to be removed as there is nothing between + // the start and the stop of the instrumentation. + eraseOpList.emplace_back(op); + eraseOpList.emplace_back(nextOp); + skipNext = true; + return WalkResult::advance(); + }); + // Remove ops. + for (Operation *op : eraseOpList) + op->erase(); + } +}; +} // namespace onnx_mlir + +/*! + * Create an instrumentation pass. + */ +std::unique_ptr onnx_mlir::createInstrumentCleanupPass() { + return std::make_unique(); +}