From 486f83faa31ae5356523da868a557619601a0e3e Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Sun, 5 Jan 2025 17:32:07 +0100 Subject: [PATCH] [mlir][Transforms][NFC] Simplify `buildUnresolvedMaterialization` implementation (#121651) The `buildUnresolvedMaterialization` implementation used to check if a materialization is necessary. A materialization is not necessary if the desired types already match the input. However, this situation can never happen: we look for mapped values with the desired type at the call sites before requesting a new unresolved materialization. The previous implementation seemed incorrect because `buildUnresolvedMaterialization` created a mapping that is never rolled back. (When in reality that code was never executed, so it is technically not incorrect.) Also fix a comment that in `findOrBuildReplacementValue` that was incorrect. --- mlir/lib/Transforms/Utils/DialectConversion.cpp | 17 +++++------------ 1 file changed, 5 insertions(+), 12 deletions(-) diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index 48b8c727a78285..8296c0c468b017 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -1430,13 +1430,8 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization( UnrealizedConversionCastOp *castOp) { assert((!originalType || kind == MaterializationKind::Target) && "original type is valid only for target materializations"); - - // Avoid materializing an unnecessary cast. - if (TypeRange(inputs) == outputTypes) { - if (!valuesToMap.empty()) - mapping.map(std::move(valuesToMap), inputs); - return inputs; - } + assert(TypeRange(inputs) != outputTypes && + "materialization is not necessary"); // Create an unresolved materialization. We use a new OpBuilder to avoid // tracking the materialization like we do for other operations. @@ -1455,7 +1450,9 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization( Value ConversionPatternRewriterImpl::findOrBuildReplacementValue( Value value, const TypeConverter *converter) { - // Find a replacement value with the same type. + // Try to find a replacement value with the same type in the conversion value + // mapping. This includes cached materializations. We try to reuse those + // instead of generating duplicate IR. ValueVector repl = mapping.lookupOrNull(value, value.getType()); if (!repl.empty()) return repl.front(); @@ -1489,10 +1486,6 @@ Value ConversionPatternRewriterImpl::findOrBuildReplacementValue( // in the conversion value mapping.) The insertion point of the // materialization must be valid for all future users that may be created // later in the conversion process. - // - // Note: Instead of creating new IR, `buildUnresolvedMaterialization` may - // return an already existing, cached materialization from the conversion - // value mapping. Value castValue = buildUnresolvedMaterialization(MaterializationKind::Source, computeInsertPoint(repl), value.getLoc(),