diff --git a/compiler/plugins/input/Torch/InputConversion/FuncConversion.cpp b/compiler/plugins/input/Torch/InputConversion/FuncConversion.cpp index 6b4cdbea8b53..b7b78bc27f68 100644 --- a/compiler/plugins/input/Torch/InputConversion/FuncConversion.cpp +++ b/compiler/plugins/input/Torch/InputConversion/FuncConversion.cpp @@ -258,14 +258,40 @@ LogicalResult ConvertedAsyncFunctionInfo::postProcess() { } // Emit the barrier and exports. + // If any of the exports are in-place we need to alias their storage to the + // provided buffers. Value coarseSignalFence = entryBlock->getArgument(entryBlock->getNumArguments() - 1); if (barrierInputs.empty()) { postambleBuilder.create(funcOp.getLoc(), coarseSignalFence); } else { + SmallVector aliasedResults; + for (auto [barrierInput, meta] : + llvm::zip_equal(barrierInputs, barrierResultMeta)) { + Value exportStorage; + Type torchType; + int returnIndex; + std::tie(exportStorage, torchType, returnIndex) = meta; + if (exportStorage) { + // Use the wait fence indicating when the storage is available for + // mutation. We need to ensure that no writes are made to the storage + // until it indicates it's safe to do so. 
+ auto waitSignalFences = getEnclosingWaitSignalFences(exportStorage); + assert(waitSignalFences && "async function missing fences"); + Value waitFence = waitSignalFences->first; + auto barrierInputDims = IREE::Util::buildDynamicDimsForValue( + barrierInput.getLoc(), barrierInput, postambleBuilder); + aliasedResults.push_back( + postambleBuilder.create( + barrierInput.getLoc(), barrierInput.getType(), barrierInput, + barrierInputDims, exportStorage, waitFence)); + } else { + aliasedResults.push_back(barrierInput); + } + } auto barrierOp = postambleBuilder.create( - funcOp.getLoc(), barrierInputs, coarseSignalFence); + funcOp.getLoc(), aliasedResults, coarseSignalFence); for (auto [barrierResult, meta] : llvm::zip_equal(barrierOp.getResults(), barrierResultMeta)) { Value exportStorage; @@ -275,13 +301,9 @@ LogicalResult ConvertedAsyncFunctionInfo::postProcess() { Value exportedValue = postambleBuilder.create( funcOp.getLoc(), postambleBuilder.getType(), barrierResult, - TypeAttr::get(barrierResult.getType()), exportStorage, StringAttr()); + TypeAttr::get(barrierResult.getType()), StringAttr()); if (returnIndex >= 0) { newReturnOperands[returnIndex] = exportedValue; - } else { - // Don't drop it. 
- postambleBuilder.create( - funcOp.getLoc(), exportedValue); } } } diff --git a/compiler/plugins/input/Torch/InputConversion/test/func_conversion.mlir b/compiler/plugins/input/Torch/InputConversion/test/func_conversion.mlir index d64bc2cd9f60..3e167ad7ba56 100644 --- a/compiler/plugins/input/Torch/InputConversion/test/func_conversion.mlir +++ b/compiler/plugins/input/Torch/InputConversion/test/func_conversion.mlir @@ -73,11 +73,11 @@ func.func @main(%arg0: !torch.vtensor<[4,5],si32>) -> !torch.vtensor<[4,5],si32> // CHECK-DAG: %[[TORCH_RESULT1:.+]] = torch.operator "mutate_inplace"(%[[TORCH_ARG1]]) // CHECK-DAG: %[[TENSOR_ARG0:.+]] = torch_c.to_builtin_tensor %[[TORCH_RESULT0]] // CHECK-DAG: %[[TENSOR_ARG1:.+]] = torch_c.to_builtin_tensor %[[TORCH_RESULT1]] -// CHECK: %[[BARRIER_RESULTS:.+]]:2 = hal.tensor.barrier join(%[[TENSOR_ARG1]], %[[TENSOR_ARG0]] : tensor<5x4xf32>, tensor<4x5xi32>) => %arg3 : !hal.fence -// CHECK-DAG: %[[EXPORT_RESULT1:.+]] = hal.tensor.export %[[BARRIER_RESULTS]]#0 into(%arg1 : !hal.buffer_view) -// CHECK-DAG: %[[UNUSED:.+]] = util.optimization_barrier %[[EXPORT_RESULT1]] -// CHECK-DAG: %[[EXPORT_RESULT0:.+]] = hal.tensor.export %[[BARRIER_RESULTS]]#1 : -// CHECK: util.return %[[EXPORT_RESULT0]] +// CHECK: %[[EXPORT_ALIAS1:.+]] = hal.tensor.alias wait(%arg2) => %[[TENSOR_ARG1]] : tensor<5x4xf32> to %arg1 : !hal.buffer_view +// CHECK: %[[BARRIER_RESULTS:.+]]:2 = hal.tensor.barrier join(%[[EXPORT_ALIAS1]], %[[TENSOR_ARG0]] : tensor<5x4xf32>, tensor<4x5xi32>) => %arg3 : !hal.fence +// CHECK-DAG: %[[EXPORT_RESULT0:.+]] = hal.tensor.export %[[BARRIER_RESULTS]]#0 +// CHECK-DAG: %[[EXPORT_RESULT1:.+]] = hal.tensor.export %[[BARRIER_RESULTS]]#1 +// CHECK: util.return %[[EXPORT_RESULT1]] builtin.module @mutable_input_overwrite_no_return { func.func @main(%arg0: !torch.vtensor<[4,5],si32>, %arg1: !torch.tensor<[5,4],f32>) -> (!torch.vtensor<[4,5],si32>) { @@ -97,9 +97,10 @@ func.func @main(%arg0: !torch.vtensor<[4,5],si32>, %arg1: 
!torch.tensor<[5,4],f3 // Not a good idea to do but legal. This verifies that if returning a mutated // tensor's intermediate value, you will get two exports, indicating a copy. // CHECK-LABEL: @mutable_input_overwrite_return_alias_copies -// CHECK: %[[BARRIER_RESULTS:.+]]:2 = hal.tensor.barrier join(%{{.*}}, %{{.*}} : tensor<5x4xf32>, tensor<5x4xf32>) -// CHECK-DAG: = hal.tensor.export %[[BARRIER_RESULTS]]#0 into(%arg0 : !hal.buffer_view) -// CHECK-DAG: = hal.tensor.export %[[BARRIER_RESULTS]]#1 : +// CHECK: %[[ALIASED:.+]] = hal.tensor.alias wait({{.+}}) => %{{.+}} : tensor<5x4xf32> to %arg0 : !hal.buffer_view +// CHECK: %[[BARRIER_RESULTS:.+]]:2 = hal.tensor.barrier join(%[[ALIASED]], %{{.*}} : tensor<5x4xf32>, tensor<5x4xf32>) +// CHECK-DAG: = hal.tensor.export %[[BARRIER_RESULTS]]#0 +// CHECK-DAG: = hal.tensor.export %[[BARRIER_RESULTS]]#1 builtin.module @mutable_input_overwrite_return_alias_copies { func.func @main(%arg0: !torch.tensor<[5,4],f32>) -> (!torch.vtensor<[5,4],f32>) { %0 = torch.copy.to_vtensor %arg0 : !torch.vtensor<[5,4],f32> diff --git a/compiler/src/iree/compiler/Bindings/Native/Transforms/WrapEntryPoints.cpp b/compiler/src/iree/compiler/Bindings/Native/Transforms/WrapEntryPoints.cpp index 1977cebec0cf..a5fd2fe77e1b 100644 --- a/compiler/src/iree/compiler/Bindings/Native/Transforms/WrapEntryPoints.cpp +++ b/compiler/src/iree/compiler/Bindings/Native/Transforms/WrapEntryPoints.cpp @@ -560,6 +560,20 @@ createExportWrapperFunc(IREE::ABI::InvocationModel invocationModel, exportOp, arguments); auto asyncResults = llvm::to_vector(callOp.getResults()); + // Alias results to storage buffers if provided. 
+ for (unsigned resultIndex = 0; resultIndex < asyncResults.size(); + ++resultIndex) { + if (!resultStorages[resultIndex]) + continue; + auto source = asyncResults[resultIndex]; + auto sourceDims = IREE::Util::buildDynamicDimsForValue( + exportOp.getLoc(), source, entryBuilder); + auto aliasOp = entryBuilder.create( + exportOp.getLoc(), source.getType(), source, sourceDims, + resultStorages[resultIndex], waitFence); + asyncResults[resultIndex] = cast(aliasOp.getResult()); + } + // Insert a barrier if requested - all tensors will be calculated and the // fence will be signaled. Note that even if there are no tensor results we // need to signal the fence. @@ -591,11 +605,9 @@ createExportWrapperFunc(IREE::ABI::InvocationModel invocationModel, resultIndex, "iree.abi.encoding"); auto dynamicDims = IREE::Util::buildDynamicDimsForValue( result.getLoc(), result, entryBuilder); - auto resultStorage = resultStorages[resultIndex]; results.push_back(entryBuilder.create( result.getLoc(), newType, result, encoding ? 
encoding : TypeAttr::get(result.getType()), dynamicDims, - resultStorage, inferResultName(entryBuilder.getContext(), resultIndex, exportOp.getResultAttrDict(resultIndex)))); } else { diff --git a/compiler/src/iree/compiler/Bindings/Native/Transforms/test/wrap_entry_points.mlir b/compiler/src/iree/compiler/Bindings/Native/Transforms/test/wrap_entry_points.mlir index d1f4751d4ce7..3780ee21a59e 100644 --- a/compiler/src/iree/compiler/Bindings/Native/Transforms/test/wrap_entry_points.mlir +++ b/compiler/src/iree/compiler/Bindings/Native/Transforms/test/wrap_entry_points.mlir @@ -105,10 +105,11 @@ util.func public @exportEncodings(%arg0: tensor {iree.abi.encoding // CHECK-NEXT: %[[ARG0_DIM0:.+]] = hal.buffer_view.dim<%[[ARG0]] : !hal.buffer_view>[0] : index // CHECK-NEXT: %[[ARG0_TENSOR:.+]] = hal.tensor.import %[[ARG0]] "input0" : !hal.buffer_view -> tensor{%[[ARG0_DIM0]]} // CHECK-NEXT: %[[RET_TENSORS:.+]]:2 = util.call @_outputStorage(%[[ARG0_TENSOR]], %[[RET1_STORAGE]]) -// CHECK: %[[RET0_DIM0:.+]] = tensor.dim %[[RET_TENSORS]]#0, %c0{{.*}} : tensor +// CHECK-DAG: %[[RET1_DIM0:.+]] = tensor.dim %[[RET_TENSORS]]#1, %c0{{.*}} : tensor +// CHECK-DAG: %[[RET1_ALIAS:.+]] = hal.tensor.alias %[[RET_TENSORS]]#1 : tensor{%[[RET1_DIM0]]} to %[[RET1_STORAGE]] : !hal.buffer +// CHECK-DAG: %[[RET0_DIM0:.+]] = tensor.dim %[[RET_TENSORS]]#0, %c0{{.*}} : tensor // CHECK-NEXT: %[[RET0_VIEW:.+]] = hal.tensor.export %[[RET_TENSORS]]#0 "output0" : tensor{%[[RET0_DIM0]]} -> !hal.buffer_view -// CHECK: %[[RET1_DIM0:.+]] = tensor.dim %[[RET_TENSORS]]#1, %c0{{.*}} : tensor -// CHECK-NEXT: %[[RET1_VIEW:.+]] = hal.tensor.export %[[RET_TENSORS]]#1 "output1" into(%[[RET1_STORAGE]] : !hal.buffer) : tensor{%[[RET1_DIM0]]} -> !hal.buffer_view +// CHECK-NEXT: %[[RET1_VIEW:.+]] = hal.tensor.export %[[RET1_ALIAS]] "output1" : tensor{%[[RET1_DIM0]]} -> !hal.buffer_view // CHECK-NEXT: util.return %[[RET0_VIEW]], %[[RET1_VIEW]] : !hal.buffer_view, !hal.buffer_view // CHECK-NEXT: } diff --git 
a/compiler/src/iree/compiler/Bindings/Native/Transforms/test/wrap_entry_points_coarse_fences.mlir b/compiler/src/iree/compiler/Bindings/Native/Transforms/test/wrap_entry_points_coarse_fences.mlir index 5b37ee4b9096..4505a54da10e 100644 --- a/compiler/src/iree/compiler/Bindings/Native/Transforms/test/wrap_entry_points_coarse_fences.mlir +++ b/compiler/src/iree/compiler/Bindings/Native/Transforms/test/wrap_entry_points_coarse_fences.mlir @@ -102,6 +102,26 @@ util.func public @tensorResultOnly() -> tensor<4xf32> { // ----- +// CHECK-LABEL: util.func public @outputStorage +// CHECK-SAME: (%[[ARG0:.+]]: !hal.buffer_view, %[[ARG1:.+]]: !hal.buffer_view, %[[RET0:.+]]: !hal.buffer, %[[RET1:.+]]: !hal.buffer, +// CHECK-SAME: %[[WAIT:.+]]: !hal.fence, %[[SIGNAL:.+]]: !hal.fence) +// CHECK: %[[RESULT_TENSORS:.+]]:2 = util.call @_outputStorage +// CHECK-DAG: %[[RESULT_ALIAS0:.+]] = hal.tensor.alias wait(%[[WAIT]]) => %[[RESULT_TENSORS]]#0 : tensor<4xf32> to %[[RET0]] : !hal.buffer +// CHECK-DAG: %[[RESULT_ALIAS1:.+]] = hal.tensor.alias wait(%[[WAIT]]) => %[[RESULT_TENSORS]]#1 : tensor<4xf32> to %[[RET1]] : !hal.buffer +// CHECK-DAG: %[[READY_RESULTS:.+]]:2 = hal.tensor.barrier join(%[[RESULT_ALIAS0]], %[[RESULT_ALIAS1]] : tensor<4xf32>, tensor<4xf32>) => %[[SIGNAL]] : !hal.fence +// CHECK-DAG: %[[EXPORT0:.+]] = hal.tensor.export %[[READY_RESULTS]]#0 "output0" +// CHECK-DAG: %[[EXPORT1:.+]] = hal.tensor.export %[[READY_RESULTS]]#1 "output1" +// CHECK-NEXT: util.return %[[EXPORT0]], %[[EXPORT1]] + +// CHECK-LABEL: util.func private @_outputStorage( +util.func public @outputStorage(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %ret0: !hal.buffer {iree.abi.output = 0 : index}, %ret1: !hal.buffer {iree.abi.output = 1 : index}) -> (tensor<4xf32>, tensor<4xf32>) { + %0 = arith.addf %arg0, %arg1 : tensor<4xf32> + %1 = arith.addf %0, %arg0 : tensor<4xf32> + util.return %0, %1 : tensor<4xf32>, tensor<4xf32> +} + +// ----- + // Tests that imported functions with the coarse-fences execution 
model // specified get wrapped with fences. Note that unlike exports controlled by // compiler flags imports only get the fences when explicitly specified so as diff --git a/compiler/src/iree/compiler/Bindings/TFLite/Transforms/WrapEntryPoints.cpp b/compiler/src/iree/compiler/Bindings/TFLite/Transforms/WrapEntryPoints.cpp index 763e4cd321bd..1ec24e84d8a6 100644 --- a/compiler/src/iree/compiler/Bindings/TFLite/Transforms/WrapEntryPoints.cpp +++ b/compiler/src/iree/compiler/Bindings/TFLite/Transforms/WrapEntryPoints.cpp @@ -539,7 +539,7 @@ class WrapEntryPointsPass } callResults.push_back(entryBuilder.create( result.getLoc(), bufferType, result, outputDynamicDims.tensorType, - dynamicDims, /*target_storage=*/nullptr, /*name=*/nullptr)); + dynamicDims, /*name=*/nullptr)); for (auto [dynamicDim, globalOp] : llvm::zip_equal(dynamicDims, outputDynamicDims.globalOps)) { globalOp.createStoreOp(result.getLoc(), dynamicDim, entryBuilder); diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/ExportBenchmarkFuncs.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/ExportBenchmarkFuncs.cpp index 5d6c3fef4137..ecfb86a2ae36 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/ExportBenchmarkFuncs.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/ExportBenchmarkFuncs.cpp @@ -147,19 +147,24 @@ static IREE::Util::GlobalOp createExportBufferGlobalOp(std::string name, Explorer &explorer) { auto loc = arg.getLoc(); - // Find a hal.tensor.export user. - IREE::HAL::TensorExportOp exportOp; + // Find a hal.tensor.export or alias user and extract the encoding. + Type sourceType; if (explorer.walkTransitiveUsers(arg, [&](Operation *op) -> WalkResult { - exportOp = dyn_cast(op); - return exportOp ? 
WalkResult::interrupt() : WalkResult::advance(); + if (auto aliasOp = dyn_cast(op)) { + sourceType = aliasOp.getResult().getType(); + return WalkResult::interrupt(); + } else if (auto exportOp = dyn_cast(op)) { + sourceType = exportOp.getSourceEncoding(); + return WalkResult::interrupt(); + } + return WalkResult::advance(); }) == TraversalResult::INCOMPLETE) { // Analysis failed to find an export op. User needs to rework their program. mlir::emitError(loc) << "unsupported dynamic buffer view export on " << arg; return {}; } - // Extract the type, which must be a static tensor. - auto sourceType = exportOp.getSourceEncoding(); + // The type must be a static tensor for this pass to work. auto tensorType = llvm::dyn_cast(sourceType); if (!tensorType || !tensorType.hasStaticShape()) { mlir::emitError(loc) << "unsupported buffer view export tensor type on " diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/export_benchmark_funcs.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/export_benchmark_funcs.mlir index 03f5b56fc369..44e76a49ad57 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/export_benchmark_funcs.mlir +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/export_benchmark_funcs.mlir @@ -95,8 +95,9 @@ util.func public @importDynamicBufferView(%view: !hal.buffer_view) -> !hal.buffe util.func public @exportBufferViewInPlace(%view: !hal.buffer_view, %storage: !hal.buffer) -> !hal.buffer_view { %0 = hal.tensor.import %view : !hal.buffer_view -> tensor<4xi32> %1 = arith.muli %0, %0 : tensor<4xi32> - %2 = hal.tensor.export %1 into(%storage : !hal.buffer) : tensor<4xi32> -> !hal.buffer_view - util.return %2 : !hal.buffer_view + %2 = hal.tensor.alias %1 : tensor<4xi32> to %storage : !hal.buffer + %3 = hal.tensor.export %2 : tensor<4xi32> -> !hal.buffer_view + util.return %3 : !hal.buffer_view } // CHECK: util.global private @[[GLOBAL_ARG0:.+]] { diff --git 
a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp index 40f63f811516..4a0c36adc322 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp @@ -474,18 +474,9 @@ LogicalResult TensorImportOp::verify() { void TensorExportOp::build(OpBuilder &builder, OperationState &result, Type resultType, Value source, TypeAttr sourceEncoding, StringAttr name) { - build(builder, result, resultType, source, sourceEncoding, - /*targetStorage=*/nullptr, name); -} - -void TensorExportOp::build(OpBuilder &builder, OperationState &result, - Type resultType, Value source, - TypeAttr sourceEncoding, Value targetStorage, - StringAttr name) { auto dynamicDims = IREE::Util::buildDynamicDimsForValue(result.location, source, builder); - build(builder, result, resultType, source, sourceEncoding, dynamicDims, - targetStorage, name); + build(builder, result, resultType, source, sourceEncoding, dynamicDims, name); } Value TensorExportOp::getTiedResult(unsigned resultIndex) { @@ -512,6 +503,33 @@ LogicalResult TensorExportOp::verify() { op.getSource().getType()); } +//===----------------------------------------------------------------------===// +// hal.tensor.alias +//===----------------------------------------------------------------------===// + +Value TensorAliasOp::getTiedResult(unsigned resultIndex) { + return IREE::Util::TiedOpInterface::findTiedBaseValue(getSource()); +} + +::std::optional +TensorAliasOp::getTiedResultOperandIndex(unsigned resultIndex) { + return {0}; // source +} + +SmallVector TensorAliasOp::getTiedResultOperandIndices() { + return {0}; // source +} + +LogicalResult TensorAliasOp::verify() { + TensorAliasOp op = *this; + auto type = llvm::cast(op.getSource().getType()); + if (type.getNumDynamicDims() != op.getSourceDims().size()) { + return op->emitOpError() + << "number of dynamic dims must match the operand type"; + } + return success(); +} + 
//===----------------------------------------------------------------------===// // hal.tensor.barrier //===----------------------------------------------------------------------===// diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td index 92894fec79df..32551f6ccf3c 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td @@ -167,7 +167,6 @@ def HAL_TensorImportOp : HAL_PureOp<"tensor.import", [ } def HAL_TensorExportOp : HAL_PureOp<"tensor.export", [ - AttrSizedOperandSegments, DeclareOpInterfaceMethods>:$target_storage, OptionalAttr:$name ); let results = (outs @@ -206,7 +199,6 @@ def HAL_TensorExportOp : HAL_PureOp<"tensor.export", [ let assemblyFormat = [{ $source ($name^)? - (`into` `(` $target_storage^ `:` type($target_storage) `)`)? `:` custom($source_encoding, type($source)) (`{` $source_dims^ `}`)? `->` @@ -221,13 +213,6 @@ def HAL_TensorExportOp : HAL_PureOp<"tensor.export", [ "TypeAttr":$sourceEncoding, "StringAttr":$name )>, - OpBuilder<(ins - "Type":$resultType, - "Value":$source, - "TypeAttr":$sourceEncoding, - "Value":$targetStorage, - "StringAttr":$name - )>, ]; let extraClassDeclaration = [{ @@ -240,6 +225,76 @@ def HAL_TensorExportOp : HAL_PureOp<"tensor.export", [ let hasFolder = 1; } +// TODO(#17328): specify an allocation policy to control behavior. +def HAL_TensorAliasOp : HAL_PureOp<"tensor.alias", [ + AllTypesMatch<["source", "result"]>, + AttrSizedOperandSegments, + DeclareOpInterfaceMethods, + Util_ShapeAwareOp, +]> { + let summary = [{hints that tensor storage should alias a HAL buffer view}]; + let description = [{ + Hints that the backing storage of an entire tensor aliases the given storage + buffer. There's no guarantee that the storage will alias and instead only + that the tensor contents will be written to the storage as if a copy had + occurred. 
This allows the compiler to avoid copies in the ideal case of a + producer that is able to produce directly into the target storage but still + handle cases where the producer is not able to be in-place. + + The storage buffer provided must have sufficient space for the tensor once + encoded. Dynamically shaped tensors may not consume the entire provided + storage. If a buffer view is provided the metadata is ignored and only the + backing buffer is used. + + An optional wait fence can be provided in cases where the storage is not + immediately available. Producers that may alias the storage will wait until + the storage is available before updating the contents. + + Explicit aliasing side-steps any analysis that may be performed by the + compiler and requires users to guarantee the safety of the aliasing. + Copy-on-write, alias analysis for overlap detection, and ordering via + use-def chains are all ignorant of the aliased buffer memory and only ensure + the compiler consumes or produces the aliased memory consistent with itself. + + Example: + ```mlir + %init = tensor.empty + %value = linalg.generic ... outs(%init) + %aliased = hal.tensor.alias %value : tensor<...> to %buffer : !hal.buffer + ... linalg.generic ins(%aliased) ... + ``` + }]; + + let arguments = (ins + AnyTensor:$source, + HAL_ShapeDynamicDims:$source_dims, + AnyTypeOf<[HAL_Buffer, HAL_BufferView]>:$storage, + Optional:$wait_fence + ); + let results = (outs + AnyTensor:$result + ); + + let assemblyFormat = [{ + (`wait` `(` $wait_fence^ `)` `=` `` `>`)? + $source `:` type($source) (`{` $source_dims^ `}`)? 
+ `to` + $storage `:` type($storage) + attr-dict + }]; + + let extraClassDeclaration = [{ + ValueRange getOperandDynamicDims(unsigned idx) { return getSourceDims(); } + ValueRange getResultDynamicDims(unsigned idx) { return getSourceDims(); } + }]; + + let hasVerifier = 1; +} + def HAL_TensorBarrierOp : HAL_Op<"tensor.barrier", [ AllTypesMatch<["sources", "results"]>, DeclareOpInterfaceMethods, %arg1: index) -> ! // ----- -// CHECK-LABEL: @tensorExportInPlace -util.func public @tensorExportInPlace(%arg0: tensor, %arg1: index, %arg2: !hal.buffer) -> !hal.buffer_view { - // CHECK: hal.tensor.export %arg0 into(%arg2 : !hal.buffer) : tensor as tensor{%arg1} -> !hal.buffer_view - %0 = hal.tensor.export %arg0 into(%arg2 : !hal.buffer) : tensor as tensor{%arg1} -> !hal.buffer_view - util.return %0 : !hal.buffer_view +// CHECK-LABEL: @tensorAlias +util.func public @tensorAlias(%arg0: tensor, %arg1: index, %arg2: !hal.buffer, %arg3: !hal.fence) -> tensor { + // CHECK: hal.tensor.alias wait(%arg3) => %arg0 : tensor{%arg1} to %arg2 : !hal.buffer + %0 = hal.tensor.alias wait(%arg3) => %arg0 : tensor{%arg1} to %arg2 : !hal.buffer + util.return %0 : tensor } // ----- diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/HALToStream/Patterns.cpp b/compiler/src/iree/compiler/Dialect/Stream/Conversion/HALToStream/Patterns.cpp index fc1b7f181930..2acd3ba4f51c 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/HALToStream/Patterns.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/HALToStream/Patterns.cpp @@ -136,53 +136,106 @@ struct ConvertTensorExportOp auto affinityAttr = IREE::Stream::AffinityAttr::lookup(op); auto source = consumeTensorOperand(op.getLoc(), adaptor.getSource(), rewriter); + + // Exporting a produced value - transfer our source value to an externally + // usable resource and directly export it. This will cause an allocation. 
+ auto exportSource = adaptor.getSource(); auto externalType = rewriter.getType( IREE::Stream::Lifetime::External); - auto exportSource = adaptor.getSource(); - auto exportSize = source.resourceSize; - if (adaptor.getTargetStorage()) { - // Query the target storage buffer length; we will only populate up to - // what is required for the output. - auto storageSize = rewriter.createOrFold( - op.getLoc(), rewriter.getIndexType(), - TypeAttr::get(op.getSource().getType()), adaptor.getSourceDims(), - affinityAttr); - - // Import the target storage as a resource that we can use as an update - // target. We overwrite the contents and just cast the storage to the - // target type so we know we can update it. - auto importOp = rewriter.create( - op.getLoc(), externalType, adaptor.getTargetStorage(), - TypeAttr::get(sourceType), adaptor.getSourceDims(), storageSize, - affinityAttr); - - // Copy the source value into the imported target storage. - auto zeroOffset = rewriter.create(op.getLoc(), 0); - auto updateOp = rewriter.create( - op.getLoc(), externalType, importOp.getResult(), - importOp.getResultSize(), zeroOffset, source.resourceSize, - source.resource, source.resourceSize, affinityAttr); - - // Export the updated resource. - // NOTE: the buffer size wrapped in the buffer view is the full size of - // the input buffer. This is so that we don't insert a data dependency on - // sparse operations or data-dependent dynamic shape dimensions. - exportSource = updateOp.getResult(); - exportSize = updateOp.getTargetSize(); - } else { - // Exporting a produced value - transfer our source value to an externally - // usable resource and directly export it. This will cause an allocation. 
- if (source.resource.getType() != externalType) { - exportSource = rewriter.create( - op.getLoc(), externalType, source.resource, source.resourceSize, - source.resourceSize, affinityAttr, affinityAttr); - } + if (source.resource.getType() != externalType) { + exportSource = rewriter.create( + op.getLoc(), externalType, source.resource, source.resourceSize, + source.resourceSize, affinityAttr, affinityAttr); } // Export (stream resource to buffer view). rewriter.replaceOpWithNewOp( op, targetType, exportSource, TypeAttr::get(sourceType), - adaptor.getSourceDims(), exportSize, affinityAttr); + adaptor.getSourceDims(), source.resourceSize, affinityAttr); + return success(); + } +}; + +// Imports the storage to alias as a resource, copies the source value into it, +// and slices out the source value. This should allow allocation placement to +// elide the update (and subsequently the slice) if possible and otherwise will +// turn into a copy. +// +// Effectively: +// %2 = hal.tensor.alias %0 : tensor<4xf32> to %1 : !hal.buffer_view +// -> +// %storage = stream.tensor.import %1 : !hal.buffer_view -> tensor<...> +// %update = stream.async.update %0, %storage[...] +// %2 = stream.async.slice %update[...] +struct ConvertTensorAliasOp + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(IREE::HAL::TensorAliasOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto sourceType = op.getSource().getType(); + auto source = + consumeTensorOperand(op.getLoc(), adaptor.getSource(), rewriter); + + // All operations (if any) will happen on the device specified by the alias + // as that indicates the affinity of the storage. + auto affinityAttr = IREE::Stream::AffinityAttr::lookup(op); + + // Query the target storage buffer length; we will only populate up to + // what is required for the output. 
+ auto storageSize = rewriter.createOrFold( + op.getLoc(), rewriter.getIndexType(), + TypeAttr::get(op.getSource().getType()), adaptor.getSourceDims(), + affinityAttr); + + // Import the target storage as a resource that we can use as an update + // target. We overwrite the contents and just cast the storage to the + // target type so we know we can update it. + auto externalType = rewriter.getType( + IREE::Stream::Lifetime::External); + auto importOp = rewriter.create( + op.getLoc(), externalType, adaptor.getStorage(), + TypeAttr::get(sourceType), adaptor.getSourceDims(), storageSize, + affinityAttr); + + // Await the fence, if needed. When not specified the storage is assumed to + // be immediately available. + Value storage = importOp.getResult(); + if (auto waitFence = op.getWaitFence()) { + Value waitTimepoint = rewriter.create( + op.getLoc(), rewriter.getType(), + ValueRange{waitFence}, affinityAttr); + storage = rewriter + .create( + op.getLoc(), ValueRange{storage}, + ValueRange{storageSize}, waitTimepoint) + .getResult(0); + } + + // Copy the source value into the imported target storage. + auto zeroOffset = rewriter.create(op.getLoc(), 0); + auto updateOp = rewriter.create( + op.getLoc(), externalType, storage, storageSize, zeroOffset, + source.resourceSize, source.resource, source.resourceSize, + affinityAttr); + + // Slice out the value from the updated tensor. + // This preserves the use-def chain but is almost always elided by aliasing + // the input value later on. + auto sliceOp = rewriter.create( + op.getLoc(), externalType, updateOp.getResult(), + updateOp.getTargetSize(), zeroOffset, source.resourceSize, + source.resourceSize, affinityAttr); + + // Transfer to match original lifetime (if needed). 
+ Value result = sliceOp.getResult(); + if (source.resource.getType() != result.getType()) { + result = rewriter.create( + op.getLoc(), source.resource.getType(), result, source.resourceSize, + source.resourceSize, affinityAttr, affinityAttr); + } + rewriter.replaceOp(op, result); + return success(); } }; @@ -233,6 +286,7 @@ void populateHALToStreamConversionPatterns(MLIRContext *context, [](IREE::HAL::BufferViewType type) { return type; }); patterns.insert(typeConverter, context); patterns.insert(typeConverter, context); + patterns.insert(typeConverter, context); patterns.insert(typeConverter, context); } diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/HALToStream/test/abi_ops.mlir b/compiler/src/iree/compiler/Dialect/Stream/Conversion/HALToStream/test/abi_ops.mlir index 8dc67055407c..c96596f28389 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/HALToStream/test/abi_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/HALToStream/test/abi_ops.mlir @@ -69,32 +69,47 @@ util.func public @exportBufferView(%tensor: tensor, %dim0: index, %di // ----- -// CHECK-LABEL: @exportBufferViewInPlace -// CHECK-SAME: (%[[TENSOR:.+]]: !stream.resource<*>, %[[SIZE:.+]]: index, %[[DIM0:.+]]: index, %[[DIM1:.+]]: index, %[[STORAGE:.+]]: !hal.buffer) -util.func public @exportBufferViewInPlace(%tensor: tensor, %dim0: index, %dim1: index, %storage: !hal.buffer) -> !hal.buffer_view { - // CHECK: %[[STORAGE_SIZE:.+]] = stream.tensor.sizeof tensor{%[[DIM0]], %[[DIM1]]} : index - // CHECK-NEXT: %[[STORAGE_IMPORT:.+]] = stream.tensor.import %[[STORAGE]] - // CHECK-SAME: : !hal.buffer -> tensor{%[[DIM0]], %[[DIM1]]} in !stream.resource{%[[STORAGE_SIZE]]} - // CHECK-NEXT: %[[STORAGE_UPDATE:.+]] = stream.async.update %[[TENSOR]], %[[STORAGE_IMPORT]][%c0 to %[[SIZE]]] - // CHECK-SAME: : !stream.resource<*>{%[[SIZE]]} -> %[[STORAGE_IMPORT]] as !stream.resource{%[[STORAGE_SIZE]]} - // CHECK-NEXT: %[[STORAGE_RESULT:.+]] = stream.tensor.export 
%[[STORAGE_UPDATE]] : - // CHECK-SAME: tensor{%[[DIM0]], %[[DIM1]]} in !stream.resource{%[[STORAGE_SIZE]]} - // CHECK-SAME: -> !hal.buffer_view - %0 = hal.tensor.export %tensor into(%storage : !hal.buffer) : tensor{%dim0, %dim1} -> !hal.buffer_view - // CHECK: util.return %[[STORAGE_RESULT]] - util.return %0 : !hal.buffer_view +// CHECK-LABEL: @aliasStorage +// CHECK-SAME: (%[[TENSOR:.+]]: !stream.resource<*>, %[[SIZE:.+]]: index, %[[DIM0:.+]]: index, %[[STORAGE:.+]]: !hal.buffer) +util.func public @aliasStorage(%tensor: tensor, %dim0: index, %storage: !hal.buffer) -> tensor { + // CHECK: %[[MIN_STORAGE_SIZE:.+]] = stream.tensor.sizeof tensor{%[[DIM0]]} + // CHECK: %[[STORAGE_RESOURCE:.+]] = stream.tensor.import %[[STORAGE]] : !hal.buffer -> tensor{%[[DIM0]]} in !stream.resource{%[[MIN_STORAGE_SIZE]]} + // CHECK: %[[UPDATE:.+]] = stream.async.update %[[TENSOR]], %[[STORAGE_RESOURCE]][%c0 to %[[SIZE]]] : !stream.resource<*>{%[[SIZE]]} -> %[[STORAGE_RESOURCE]] as !stream.resource{%[[MIN_STORAGE_SIZE]]} + // CHECK: %[[SLICE:.+]] = stream.async.slice %[[UPDATE]][%c0 to %[[SIZE]]] : !stream.resource{%[[MIN_STORAGE_SIZE]]} -> !stream.resource{%[[SIZE]]} + // CHECK: %[[RESULT:.+]] = stream.async.transfer %[[SLICE]] : !stream.resource{%[[SIZE]]} -> !stream.resource<*>{%[[SIZE]]} + %0 = hal.tensor.alias %tensor : tensor{%dim0} to %storage : !hal.buffer + // CHECK: util.return %[[RESULT]] + util.return %0 : tensor } // ----- -// As with @exportBufferViewInPlace above but using !hal.buffer_view storage. 
+// CHECK-LABEL: @aliasStorageAsync +// CHECK-SAME: (%[[TENSOR:.+]]: !stream.resource<*>, %[[SIZE:.+]]: index, %[[DIM0:.+]]: index, %[[STORAGE:.+]]: !hal.buffer, %[[FENCE:.+]]: !hal.fence) +util.func public @aliasStorageAsync(%tensor: tensor, %dim0: index, %storage: !hal.buffer, %fence: !hal.fence) -> tensor { + // CHECK-DAG: %[[MIN_STORAGE_SIZE:.+]] = stream.tensor.sizeof tensor{%[[DIM0]]} + // CHECK-DAG: %[[UNREADY_STORAGE:.+]] = stream.tensor.import %[[STORAGE]] : !hal.buffer -> tensor{%[[DIM0]]} in !stream.resource{%[[MIN_STORAGE_SIZE]]} + // CHECK-DAG: %[[TIMEPOINT:.+]] = stream.timepoint.import %[[FENCE]] + // CHECK-DAG: %[[READY_STORAGE:.+]] = stream.timepoint.await %[[TIMEPOINT]] => %[[UNREADY_STORAGE]] : !stream.resource{%[[MIN_STORAGE_SIZE]]} + // CHECK: %[[UPDATE:.+]] = stream.async.update %[[TENSOR]], %[[READY_STORAGE]][%c0 to %[[SIZE]]] : !stream.resource<*>{%[[SIZE]]} -> %[[READY_STORAGE]] as !stream.resource{%[[MIN_STORAGE_SIZE]]} + // CHECK: %[[SLICE:.+]] = stream.async.slice %[[UPDATE]][%c0 to %[[SIZE]]] : !stream.resource{%[[MIN_STORAGE_SIZE]]} -> !stream.resource{%[[SIZE]]} + // CHECK: %[[RESULT:.+]] = stream.async.transfer %[[SLICE]] : !stream.resource{%[[SIZE]]} -> !stream.resource<*>{%[[SIZE]]} + %0 = hal.tensor.alias wait(%fence) => %tensor : tensor{%dim0} to %storage : !hal.buffer + // CHECK: util.return %[[RESULT]] + util.return %0 : tensor +} -// CHECK-LABEL: @exportBufferViewInPlaceToView -// CHECK-SAME: (%[[TENSOR:.+]]: !stream.resource<*>, %[[SIZE:.+]]: index, %[[DIM0:.+]]: index, %[[DIM1:.+]]: index, %[[STORAGE:.+]]: !hal.buffer_view) -util.func public @exportBufferViewInPlaceToView(%tensor: tensor, %dim0: index, %dim1: index, %storage: !hal.buffer_view) -> !hal.buffer_view { - // CHECK: %[[STORAGE_SIZE:.+]] = stream.tensor.sizeof tensor{%[[DIM0]], %[[DIM1]]} : index - // CHECK-NEXT: %[[STORAGE_IMPORT:.+]] = stream.tensor.import %[[STORAGE]] - // CHECK-SAME: : !hal.buffer_view -> tensor{%[[DIM0]], %[[DIM1]]} in 
!stream.resource{%[[STORAGE_SIZE]]} - %0 = hal.tensor.export %tensor into(%storage : !hal.buffer_view) : tensor{%dim0, %dim1} -> !hal.buffer_view - util.return %0 : !hal.buffer_view +// ----- + +// CHECK-LABEL: @tensorBarrier +// CHECK-SAME: (%[[TENSOR0:.+]]: !stream.resource<*>, %[[SIZE0:.+]]: index, %[[TENSOR1:.+]]: !stream.resource<*>, %[[SIZE1:.+]]: index, %[[FENCE:.+]]: !hal.fence) +util.func public @tensorBarrier(%tensor0: tensor<3xf32>, %tensor1: tensor, %fence: !hal.fence) -> (tensor<3xf32>, tensor) { + // CHECK-DAG: %[[TENSOR0_AFTER:.+]], %[[TENSOR0_BARRIER:.+]] = stream.timepoint.barrier %[[TENSOR0]] : !stream.resource<*>{%[[SIZE0]]} => !stream.timepoint + // CHECK-DAG: %[[TENSOR1_AFTER:.+]], %[[TENSOR1_BARRIER:.+]] = stream.timepoint.barrier %[[TENSOR1]] : !stream.resource<*>{%[[SIZE1]]} => !stream.timepoint + // CHECK-NEXT: %[[JOIN:.+]] = stream.timepoint.join max(%[[TENSOR0_BARRIER]], %[[TENSOR1_BARRIER]]) => !stream.timepoint + // CHECK-NEXT: stream.timepoint.chain_external %[[JOIN]] => (%[[FENCE]] : !hal.fence) + %0:2 = hal.tensor.barrier join(%tensor0, %tensor1 : tensor<3xf32>, tensor) => %fence : !hal.fence + // CHECK: util.return %[[TENSOR0_AFTER]], %[[SIZE0]], %[[TENSOR1_AFTER]], %[[SIZE1]] + util.return %0#0, %0#1 : tensor<3xf32>, tensor } + diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp index 3b0df17950d3..3982c4e2d49e 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp @@ -1745,6 +1745,95 @@ OpFoldResult AsyncUpdateOp::fold(FoldAdaptor operands) { namespace { +// Detects updates that are overwriting entire tensors that could be folded into +// an in-place producer operation. +// +// Example: +// %1 = stream.async.dispatch ... %0 -> %0 +// %2 = stream.async.update %1, %0[full] +// -> +// %2 = stream.async.dispatch .... 
%0 -> %0 +struct ElideInPlaceUpdate : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AsyncUpdateOp updateOp, + PatternRewriter &rewriter) const override { + // Look for entire tensor replacement. + const bool isOverwrite = + updateOp.getUpdateSize() == updateOp.getTargetSize() && + updateOp.getUpdate().getType() == updateOp.getType(); + if (!isOverwrite) { + // Ignore partial updates. + return failure(); + } + + // Detect producers that are performing their operation in-place. + // This should be a global analysis to detect in-place operations across + // control flow/calls + auto updateOperand = + IREE::Util::TiedOpInterface::findTiedBaseValue(updateOp.getUpdate()); + auto targetOperand = + IREE::Util::TiedOpInterface::findTiedBaseValue(updateOp.getTarget()); + + // Condition: if overwriting the entire tied operand then this is a no-op. + if (updateOperand != targetOperand) { + return rewriter.notifyMatchFailure( + updateOp, + "update overwrite target is not tied to the producer operand"); + } + + // If there are multiple users we may need them for copies. We need to + // ensure that any uses of the produced update are reads - if not + // copy-on-write will require the update op to exist in order to identify + // the condition. + SmallVector accessRanges; + for (auto &use : updateOp.getUpdate().getUses()) { + // Ops that are async resource access aware let us see if the particular + // source update resource is written. Any tied ops will have their uses + // marked as writes so that we don't need to walk down all transitive + // users to detect writes. 
+ if (auto accessOp = + dyn_cast(use.getOwner())) { + accessOp.getAsyncAccessRanges(accessRanges); + for (auto &accessRange : accessRanges) { + if (accessRange.resource == updateOp.getUpdate() && + !accessRange.isReadOnly()) { + // TODO(benvanik): allow non-overlapping writes by checking + // accessRange.mayOverlap - may not be worth it due to the update + // source being the entire resource. If we did this on copies as + // well we'd want that. + return rewriter.notifyMatchFailure( + updateOp, "usage writes the update resource and conservatively " + "blocks elision"); + } + } + accessRanges.clear(); + continue; + } + + // Use memory effect analysis for ops in other dialects that don't + // indicate their access ranges. We conservatively fail if the op doesn't + // declare memory effects. + std::optional> effects = + getEffectsRecursively(use.getOwner()); + if (!effects) { + // Effect analysis failed. + return rewriter.notifyMatchFailure( + updateOp, "usage has unknown memory effects and blocks elision"); + } + for (const MemoryEffects::EffectInstance &effect : *effects) { + if (isa(effect.getEffect())) { + // Write effect indicates something we can't analyze (today). + return rewriter.notifyMatchFailure( + updateOp, "usage has side-effects and blocks elision"); + } + } + } + + rewriter.replaceOp(updateOp, updateOp.getUpdate()); + return success(); + } +}; + // Turns a splat+update-from into a fill. // // Example: @@ -1816,6 +1905,7 @@ void AsyncUpdateOp::getCanonicalizationPatterns(RewritePatternSet &results, // affinity/lifetime differ. // TODO(#6972): updates into splats could become alloca + fill exclusive // region + update into undefined contents (used in padding). 
+ results.insert(context); results.insert(context); results.insert(context); results.insert>(context); diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/test/async_folding.mlir b/compiler/src/iree/compiler/Dialect/Stream/IR/test/async_folding.mlir index 64ae8f3a2591..155a857205ca 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/IR/test/async_folding.mlir +++ b/compiler/src/iree/compiler/Dialect/Stream/IR/test/async_folding.mlir @@ -288,6 +288,68 @@ util.func private @FoldLocalAsyncUpdateOp(%arg0: !stream.resource<*>, %arg1: ind // ----- +// CHECK-LABEL: @ElideInPlaceUpdateUpdate +util.func private @ElideInPlaceUpdateUpdate(%arg0: !stream.resource<*>, %arg1: index, %arg2: !stream.resource<*>, %arg3: index) -> !stream.resource<*> { + %c0 = arith.constant 0 : index + // CHECK: %[[RESULT:.+]] = stream.async.update %arg0, %arg2[%c0 to %arg1] : !stream.resource<*>{%arg1} -> %arg2 as !stream.resource<*>{%arg3} + %0 = stream.async.update %arg0, %arg2[%c0 to %arg1] : !stream.resource<*>{%arg1} -> %arg2 as !stream.resource<*>{%arg3} + // CHECK-NOT: stream.async.update + %1 = stream.async.update %0, %arg2[%c0 to %arg3] : !stream.resource<*>{%arg3} -> %arg2 as !stream.resource<*>{%arg3} + // CHECK: util.return %[[RESULT]] + util.return %1 : !stream.resource<*> +} + +// ----- + +// CHECK-LABEL: @ElideInPlaceUpdateDispatch +util.func private @ElideInPlaceUpdateDispatch(%arg0: !stream.resource<*>, %arg1: index) -> !stream.resource<*> { + %c0 = arith.constant 0 : index + // CHECK: %[[RESULT:.+]] = stream.async.dispatch @ex::@fn(%arg0[%c0 to %arg1 for %arg1]) : (!stream.resource<*>{%arg1}) -> %arg0{%arg1} + %0 = stream.async.dispatch @ex::@fn(%arg0[%c0 to %arg1 for %arg1]) : (!stream.resource<*>{%arg1}) -> %arg0{%arg1} + // CHECK-NOT: stream.async.update + %1 = stream.async.update %0, %arg0[%c0 to %arg1] : !stream.resource<*>{%arg1} -> %arg0 as !stream.resource<*>{%arg1} + // CHECK: util.return %[[RESULT]] + util.return %1 : !stream.resource<*> +} + +// ----- + +// 
Tests that multiple users of the produced value will still allow the update +// to be elided so long as they are reads. + +// CHECK-LABEL: @ElideInPlaceUpdateDispatchMultiUse +util.func private @ElideInPlaceUpdateDispatchMultiUse(%arg0: !stream.resource<*>, %arg1: index) -> (!stream.resource<*>, !stream.resource<*>) { + %c0 = arith.constant 0 : index + // CHECK: %[[RESULT0:.+]] = stream.async.dispatch @ex::@fn0(%arg0[%c0 to %arg1 for %arg1]) : (!stream.resource<*>{%arg1}) -> %arg0{%arg1} + %0 = stream.async.dispatch @ex::@fn0(%arg0[%c0 to %arg1 for %arg1]) : (!stream.resource<*>{%arg1}) -> %arg0{%arg1} + // CHECK-NOT: stream.async.update + %1 = stream.async.update %0, %arg0[%c0 to %arg1] : !stream.resource<*>{%arg1} -> %arg0 as !stream.resource<*>{%arg1} + // CHECK: %[[RESULT1:.+]] = stream.async.dispatch @ex::@fn1(%[[RESULT0]][%c0 to %arg1 for %arg1]) : (!stream.resource<*>{%arg1}) -> !stream.resource<*>{%arg1} + %2 = stream.async.dispatch @ex::@fn1(%0[%c0 to %arg1 for %arg1]) : (!stream.resource<*>{%arg1}) -> !stream.resource<*>{%arg1} + // CHECK: util.return %[[RESULT0]], %[[RESULT1]] + util.return %1, %2 : !stream.resource<*>, !stream.resource<*> +} + +// ----- + +// Tests that writes on the update source will fail to elide the update. +// TODO(benvanik): support looking for writes only prior to the update that are +// known-safe. 
+ +// CHECK-LABEL: @ElideInPlaceUpdateDispatchMultiUseWrite +util.func private @ElideInPlaceUpdateDispatchMultiUseWrite(%arg0: !stream.resource<*>, %arg1: index) -> (!stream.resource<*>, !stream.resource<*>) { + %c0 = arith.constant 0 : index + // CHECK: stream.async.dispatch @ex::@fn0 + %0 = stream.async.dispatch @ex::@fn0(%arg0[%c0 to %arg1 for %arg1]) : (!stream.resource<*>{%arg1}) -> %arg0{%arg1} + // CHECK: stream.async.update + %1 = stream.async.update %0, %arg0[%c0 to %arg1] : !stream.resource<*>{%arg1} -> %arg0 as !stream.resource<*>{%arg1} + // CHECK: stream.async.dispatch @ex::@fn1 + %2 = stream.async.dispatch @ex::@fn1(%0[%c0 to %arg1 for %arg1]) : (!stream.resource<*>{%arg1}) -> %0{%arg1} + util.return %1, %2 : !stream.resource<*>, !stream.resource<*> +} + +// ----- + // CHECK-LABEL: @CombineSplatUpdateFromToFill util.func private @CombineSplatUpdateFromToFill(%arg0: !stream.resource<*>, %arg1: index) -> !stream.resource<*> { %c0 = arith.constant 0 : index diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilTypes.cpp b/compiler/src/iree/compiler/Dialect/Util/IR/UtilTypes.cpp index 3d0d28e9d6e2..fcdb15f91635 100644 --- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilTypes.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilTypes.cpp @@ -206,11 +206,15 @@ bool tryMoveProducerBefore(Value value, Operation *consumerOp) { // satisfies the request based on SSA dominance. return true; } + + // Recursively try to move each operand. + // TODO(benvanik): change to a worklist to avoid potential stack explosion. 
for (auto operand : producerOp->getOperands()) { - if (!isValueUsableForOp(operand, consumerOp)) { + if (!tryMoveProducerBefore(operand, consumerOp)) { return false; } } + producerOp->moveBefore(consumerOp); return true; } diff --git a/compiler/src/iree/compiler/InputConversion/Common/IREEImportPublic.cpp b/compiler/src/iree/compiler/InputConversion/Common/IREEImportPublic.cpp index dececb2c6db0..1c3664803ff7 100644 --- a/compiler/src/iree/compiler/InputConversion/Common/IREEImportPublic.cpp +++ b/compiler/src/iree/compiler/InputConversion/Common/IREEImportPublic.cpp @@ -244,7 +244,7 @@ class TensorExportPattern rewriter.replaceOpWithNewOp( srcOp, resultType, adaptor.getSource(), TypeAttr::get(adaptor.getSource().getType()), adaptor.getSourceDims(), - /*target_storage=*/nullptr, /*name=*/nullptr); + /*name=*/nullptr); } return success(); }