From d2dd9e23eff96787a66871239c32c627fc3da51b Mon Sep 17 00:00:00 2001 From: Ben Vanik Date: Sat, 11 May 2024 08:51:54 -0700 Subject: [PATCH] Replacing hal.tensor.export storage for hal.tensor.alias. (#17339) This fixes a design issue in the original `hal.tensor.export` optional storage feature that would lead to the export happening after any `hal.tensor.barrier` ops that may have been used on the source tensor. The new op is intended to be inserted prior to the barriers and can also be inserted elsewhere (not just at ABI boundaries). Minor improvements were required to folding of `stream.async.update` in order to ensure the aliased buffers are used in cases where barriers are present between producers and the alias ops consuming the values. #17135 made the folder too conservative and would result in all in-place operations of external values getting extra copies. Fixes #17316. --- .../Torch/InputConversion/FuncConversion.cpp | 34 ++++- .../InputConversion/test/func_conversion.mlir | 17 +-- .../Native/Transforms/WrapEntryPoints.cpp | 16 ++- .../Transforms/test/wrap_entry_points.mlir | 7 +- .../test/wrap_entry_points_coarse_fences.mlir | 20 +++ .../TFLite/Transforms/WrapEntryPoints.cpp | 2 +- .../Flow/Transforms/ExportBenchmarkFuncs.cpp | 17 ++- .../test/export_benchmark_funcs.mlir | 5 +- .../iree/compiler/Dialect/HAL/IR/HALOps.cpp | 38 +++-- .../iree/compiler/Dialect/HAL/IR/HALOps.td | 85 +++++++++-- .../Dialect/HAL/IR/test/tensor_ops.mlir | 10 +- .../Conversion/HALToStream/Patterns.cpp | 134 ++++++++++++------ .../Conversion/HALToStream/test/abi_ops.mlir | 61 +++++--- .../Dialect/Stream/IR/StreamOpFolders.cpp | 90 ++++++++++++ .../Dialect/Stream/IR/test/async_folding.mlir | 62 ++++++++ .../compiler/Dialect/Util/IR/UtilTypes.cpp | 6 +- .../Common/IREEImportPublic.cpp | 2 +- 17 files changed, 483 insertions(+), 123 deletions(-) diff --git a/compiler/plugins/input/Torch/InputConversion/FuncConversion.cpp b/compiler/plugins/input/Torch/InputConversion/FuncConversion.cpp index 6b4cdbea8b53..b7b78bc27f68 100644 --- a/compiler/plugins/input/Torch/InputConversion/FuncConversion.cpp +++ b/compiler/plugins/input/Torch/InputConversion/FuncConversion.cpp @@ -258,14 +258,40 @@ LogicalResult ConvertedAsyncFunctionInfo::postProcess() { } // Emit the barrier and exports. + // If any of the exports are in-place we need to alias their storage to the + // provided buffers. Value coarseSignalFence = entryBlock->getArgument(entryBlock->getNumArguments() - 1); if (barrierInputs.empty()) { postambleBuilder.create(funcOp.getLoc(), coarseSignalFence); } else { + SmallVector aliasedResults; + for (auto [barrierInput, meta] : + llvm::zip_equal(barrierInputs, barrierResultMeta)) { + Value exportStorage; + Type torchType; + int returnIndex; + std::tie(exportStorage, torchType, returnIndex) = meta; + if (exportStorage) { + // Use the wait fence indicating when the storage is available for + // mutation. We need to ensure that no writes are made to the storage + // until it indicates it's safe to do so. + auto waitSignalFences = getEnclosingWaitSignalFences(exportStorage); + assert(waitSignalFences && "async function missing fences"); + Value waitFence = waitSignalFences->first; + auto barrierInputDims = IREE::Util::buildDynamicDimsForValue( + barrierInput.getLoc(), barrierInput, postambleBuilder); + aliasedResults.push_back( + postambleBuilder.create( + barrierInput.getLoc(), barrierInput.getType(), barrierInput, + barrierInputDims, exportStorage, waitFence)); + } else { + aliasedResults.push_back(barrierInput); + } + } auto barrierOp = postambleBuilder.create( - funcOp.getLoc(), barrierInputs, coarseSignalFence); + funcOp.getLoc(), aliasedResults, coarseSignalFence); for (auto [barrierResult, meta] : llvm::zip_equal(barrierOp.getResults(), barrierResultMeta)) { Value exportStorage; @@ -275,13 +301,9 @@ LogicalResult ConvertedAsyncFunctionInfo::postProcess() { Value exportedValue = postambleBuilder.create( funcOp.getLoc(), postambleBuilder.getType(), barrierResult, - TypeAttr::get(barrierResult.getType()), exportStorage, StringAttr()); + TypeAttr::get(barrierResult.getType()), StringAttr()); if (returnIndex >= 0) { newReturnOperands[returnIndex] = exportedValue; - } else { - // Don't drop it. - postambleBuilder.create( - funcOp.getLoc(), exportedValue); } } } diff --git a/compiler/plugins/input/Torch/InputConversion/test/func_conversion.mlir b/compiler/plugins/input/Torch/InputConversion/test/func_conversion.mlir index d64bc2cd9f60..3e167ad7ba56 100644 --- a/compiler/plugins/input/Torch/InputConversion/test/func_conversion.mlir +++ b/compiler/plugins/input/Torch/InputConversion/test/func_conversion.mlir @@ -73,11 +73,11 @@ func.func @main(%arg0: !torch.vtensor<[4,5],si32>) -> !torch.vtensor<[4,5],si32> // CHECK-DAG: %[[TORCH_RESULT1:.+]] = torch.operator "mutate_inplace"(%[[TORCH_ARG1]]) // CHECK-DAG: %[[TENSOR_ARG0:.+]] = torch_c.to_builtin_tensor %[[TORCH_RESULT0]] // CHECK-DAG: %[[TENSOR_ARG1:.+]] = torch_c.to_builtin_tensor %[[TORCH_RESULT1]] -// CHECK: %[[BARRIER_RESULTS:.+]]:2 = hal.tensor.barrier join(%[[TENSOR_ARG1]], %[[TENSOR_ARG0]] : tensor<5x4xf32>, tensor<4x5xi32>) => %arg3 : !hal.fence -// CHECK-DAG: %[[EXPORT_RESULT1:.+]] = hal.tensor.export %[[BARRIER_RESULTS]]#0 into(%arg1 : !hal.buffer_view) -// CHECK-DAG: %[[UNUSED:.+]] = util.optimization_barrier %[[EXPORT_RESULT1]] -// CHECK-DAG: %[[EXPORT_RESULT0:.+]] = hal.tensor.export %[[BARRIER_RESULTS]]#1 : -// CHECK: util.return %[[EXPORT_RESULT0]] +// CHECK: %[[EXPORT_ALIAS1:.+]] = hal.tensor.alias wait(%arg2) => %[[TENSOR_ARG1]] : tensor<5x4xf32> to %arg1 : !hal.buffer_view +// CHECK: %[[BARRIER_RESULTS:.+]]:2 = hal.tensor.barrier join(%[[EXPORT_ALIAS1]], %[[TENSOR_ARG0]] : tensor<5x4xf32>, tensor<4x5xi32>) => %arg3 : !hal.fence +// CHECK-DAG: %[[EXPORT_RESULT0:.+]] = hal.tensor.export %[[BARRIER_RESULTS]]#0 +// CHECK-DAG: %[[EXPORT_RESULT1:.+]] = hal.tensor.export %[[BARRIER_RESULTS]]#1 +// CHECK: util.return %[[EXPORT_RESULT1]] builtin.module @mutable_input_overwrite_no_return { func.func @main(%arg0: !torch.vtensor<[4,5],si32>, %arg1: !torch.tensor<[5,4],f32>) -> (!torch.vtensor<[4,5],si32>) { @@ -97,9 +97,10 @@ func.func @main(%arg0: !torch.vtensor<[4,5],si32>, %arg1: !torch.tensor<[5,4],f3 // Not a good idea to do but legal. This verifies that if returning a mutated // tensor's intermediate value, you will get two exports, indicating a copy. // CHECK-LABEL: @mutable_input_overwrite_return_alias_copies -// CHECK: %[[BARRIER_RESULTS:.+]]:2 = hal.tensor.barrier join(%{{.*}}, %{{.*}} : tensor<5x4xf32>, tensor<5x4xf32>) -// CHECK-DAG: = hal.tensor.export %[[BARRIER_RESULTS]]#0 into(%arg0 : !hal.buffer_view) -// CHECK-DAG: = hal.tensor.export %[[BARRIER_RESULTS]]#1 : +// CHECK: %[[ALIASED:.+]] = hal.tensor.alias wait({{.+}}) => %{{.+}} : tensor<5x4xf32> to %arg0 : !hal.buffer_view +// CHECK: %[[BARRIER_RESULTS:.+]]:2 = hal.tensor.barrier join(%[[ALIASED]], %{{.*}} : tensor<5x4xf32>, tensor<5x4xf32>) +// CHECK-DAG: = hal.tensor.export %[[BARRIER_RESULTS]]#0 +// CHECK-DAG: = hal.tensor.export %[[BARRIER_RESULTS]]#1 builtin.module @mutable_input_overwrite_return_alias_copies { func.func @main(%arg0: !torch.tensor<[5,4],f32>) -> (!torch.vtensor<[5,4],f32>) { %0 = torch.copy.to_vtensor %arg0 : !torch.vtensor<[5,4],f32> diff --git a/compiler/src/iree/compiler/Bindings/Native/Transforms/WrapEntryPoints.cpp b/compiler/src/iree/compiler/Bindings/Native/Transforms/WrapEntryPoints.cpp index 1977cebec0cf..a5fd2fe77e1b 100644 --- a/compiler/src/iree/compiler/Bindings/Native/Transforms/WrapEntryPoints.cpp +++ b/compiler/src/iree/compiler/Bindings/Native/Transforms/WrapEntryPoints.cpp @@ -560,6 +560,20 @@ createExportWrapperFunc(IREE::ABI::InvocationModel invocationModel, exportOp, arguments); auto asyncResults = llvm::to_vector(callOp.getResults()); + // Alias results to storage buffers if provided. + for (unsigned resultIndex = 0; resultIndex < asyncResults.size(); + ++resultIndex) { + if (!resultStorages[resultIndex]) + continue; + auto source = asyncResults[resultIndex]; + auto sourceDims = IREE::Util::buildDynamicDimsForValue( + exportOp.getLoc(), source, entryBuilder); + auto aliasOp = entryBuilder.create( + exportOp.getLoc(), source.getType(), source, sourceDims, + resultStorages[resultIndex], waitFence); + asyncResults[resultIndex] = cast(aliasOp.getResult()); + } + // Insert a barrier if requested - all tensors will be calculated and the // fence will be signaled. Note that even if there are no tensor results we // need to signal the fence. @@ -591,11 +605,9 @@ createExportWrapperFunc(IREE::ABI::InvocationModel invocationModel, resultIndex, "iree.abi.encoding"); auto dynamicDims = IREE::Util::buildDynamicDimsForValue( result.getLoc(), result, entryBuilder); - auto resultStorage = resultStorages[resultIndex]; results.push_back(entryBuilder.create( result.getLoc(), newType, result, encoding ? encoding : TypeAttr::get(result.getType()), dynamicDims, - resultStorage, inferResultName(entryBuilder.getContext(), resultIndex, exportOp.getResultAttrDict(resultIndex)))); } else { diff --git a/compiler/src/iree/compiler/Bindings/Native/Transforms/test/wrap_entry_points.mlir b/compiler/src/iree/compiler/Bindings/Native/Transforms/test/wrap_entry_points.mlir index d1f4751d4ce7..3780ee21a59e 100644 --- a/compiler/src/iree/compiler/Bindings/Native/Transforms/test/wrap_entry_points.mlir +++ b/compiler/src/iree/compiler/Bindings/Native/Transforms/test/wrap_entry_points.mlir @@ -105,10 +105,11 @@ util.func public @exportEncodings(%arg0: tensor {iree.abi.encoding // CHECK-NEXT: %[[ARG0_DIM0:.+]] = hal.buffer_view.dim<%[[ARG0]] : !hal.buffer_view>[0] : index // CHECK-NEXT: %[[ARG0_TENSOR:.+]] = hal.tensor.import %[[ARG0]] "input0" : !hal.buffer_view -> tensor{%[[ARG0_DIM0]]} // CHECK-NEXT: %[[RET_TENSORS:.+]]:2 = util.call @_outputStorage(%[[ARG0_TENSOR]], %[[RET1_STORAGE]]) -// CHECK: %[[RET0_DIM0:.+]] = tensor.dim %[[RET_TENSORS]]#0, %c0{{.*}} : tensor +// CHECK-DAG: %[[RET1_DIM0:.+]] = tensor.dim %[[RET_TENSORS]]#1, %c0{{.*}} : tensor +// CHECK-DAG: %[[RET1_ALIAS:.+]] = hal.tensor.alias %[[RET_TENSORS]]#1 : tensor{%[[RET1_DIM0]]} to %[[RET1_STORAGE]] : !hal.buffer +// CHECK-DAG: %[[RET0_DIM0:.+]] = tensor.dim %[[RET_TENSORS]]#0, %c0{{.*}} : tensor // CHECK-NEXT: %[[RET0_VIEW:.+]] = hal.tensor.export %[[RET_TENSORS]]#0 "output0" : tensor{%[[RET0_DIM0]]} -> !hal.buffer_view -// CHECK: %[[RET1_DIM0:.+]] = tensor.dim %[[RET_TENSORS]]#1, %c0{{.*}} : tensor -// CHECK-NEXT: %[[RET1_VIEW:.+]] = hal.tensor.export %[[RET_TENSORS]]#1 "output1" into(%[[RET1_STORAGE]] : !hal.buffer) : tensor{%[[RET1_DIM0]]} -> !hal.buffer_view +// CHECK-NEXT: %[[RET1_VIEW:.+]] = hal.tensor.export %[[RET1_ALIAS]] "output1" : tensor{%[[RET1_DIM0]]} -> !hal.buffer_view // CHECK-NEXT: util.return %[[RET0_VIEW]], %[[RET1_VIEW]] : !hal.buffer_view, !hal.buffer_view // CHECK-NEXT: } diff --git a/compiler/src/iree/compiler/Bindings/Native/Transforms/test/wrap_entry_points_coarse_fences.mlir b/compiler/src/iree/compiler/Bindings/Native/Transforms/test/wrap_entry_points_coarse_fences.mlir index 5b37ee4b9096..4505a54da10e 100644 --- a/compiler/src/iree/compiler/Bindings/Native/Transforms/test/wrap_entry_points_coarse_fences.mlir +++ b/compiler/src/iree/compiler/Bindings/Native/Transforms/test/wrap_entry_points_coarse_fences.mlir @@ -102,6 +102,26 @@ util.func public @tensorResultOnly() -> tensor<4xf32> { // ----- +// CHECK-LABEL: util.func public @outputStorage +// CHECK-SAME: (%[[ARG0:.+]]: !hal.buffer_view, %[[ARG1:.+]]: !hal.buffer_view, %[[RET0:.+]]: !hal.buffer, %[[RET1:.+]]: !hal.buffer, +// CHECK-SAME: %[[WAIT:.+]]: !hal.fence, %[[SIGNAL:.+]]: !hal.fence) +// CHECK: %[[RESULT_TENSORS:.+]]:2 = util.call @_outputStorage +// CHECK-DAG: %[[RESULT_ALIAS0:.+]] = hal.tensor.alias wait(%[[WAIT]]) => %[[RESULT_TENSORS]]#0 : tensor<4xf32> to %[[RET0]] : !hal.buffer +// CHECK-DAG: %[[RESULT_ALIAS1:.+]] = hal.tensor.alias wait(%[[WAIT]]) => %[[RESULT_TENSORS]]#1 : tensor<4xf32> to %[[RET1]] : !hal.buffer +// CHECK-DAG: %[[READY_RESULTS:.+]]:2 = hal.tensor.barrier join(%[[RESULT_ALIAS0]], %[[RESULT_ALIAS1]] : tensor<4xf32>, tensor<4xf32>) => %[[SIGNAL]] : !hal.fence +// CHECK-DAG: %[[EXPORT0:.+]] = hal.tensor.export %[[READY_RESULTS]]#0 "output0" +// CHECK-DAG: %[[EXPORT1:.+]] = hal.tensor.export %[[READY_RESULTS]]#1 "output1" +// CHECK-NEXT: util.return %[[EXPORT0]], %[[EXPORT1]] + +// CHECK-LABEL: util.func private @_outputStorage( +util.func public @outputStorage(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %ret0: !hal.buffer {iree.abi.output = 0 : index}, %ret1: !hal.buffer {iree.abi.output = 1 : index}) -> (tensor<4xf32>, tensor<4xf32>) { + %0 = arith.addf %arg0, %arg1 : tensor<4xf32> + %1 = arith.addf %0, %arg0 : tensor<4xf32> + util.return %0, %1 : tensor<4xf32>, tensor<4xf32> +} + +// ----- + // Tests that imported functions with the coarse-fences execution model // specified get wrapped with fences. Note that unlike exports controlled by // compiler flags imports only get the fences when explicitly specified so as diff --git a/compiler/src/iree/compiler/Bindings/TFLite/Transforms/WrapEntryPoints.cpp b/compiler/src/iree/compiler/Bindings/TFLite/Transforms/WrapEntryPoints.cpp index 763e4cd321bd..1ec24e84d8a6 100644 --- a/compiler/src/iree/compiler/Bindings/TFLite/Transforms/WrapEntryPoints.cpp +++ b/compiler/src/iree/compiler/Bindings/TFLite/Transforms/WrapEntryPoints.cpp @@ -539,7 +539,7 @@ class WrapEntryPointsPass } callResults.push_back(entryBuilder.create( result.getLoc(), bufferType, result, outputDynamicDims.tensorType, - dynamicDims, /*target_storage=*/nullptr, /*name=*/nullptr)); + dynamicDims, /*name=*/nullptr)); for (auto [dynamicDim, globalOp] : llvm::zip_equal(dynamicDims, outputDynamicDims.globalOps)) { globalOp.createStoreOp(result.getLoc(), dynamicDim, entryBuilder); diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/ExportBenchmarkFuncs.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/ExportBenchmarkFuncs.cpp index 5d6c3fef4137..ecfb86a2ae36 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/ExportBenchmarkFuncs.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/ExportBenchmarkFuncs.cpp @@ -147,19 +147,24 @@ static IREE::Util::GlobalOp createExportBufferGlobalOp(std::string name, Explorer &explorer) { auto loc = arg.getLoc(); - // Find a hal.tensor.export user. - IREE::HAL::TensorExportOp exportOp; + // Find a hal.tensor.export or alias user and extract the encoding. + Type sourceType; if (explorer.walkTransitiveUsers(arg, [&](Operation *op) -> WalkResult { - exportOp = dyn_cast(op); - return exportOp ? WalkResult::interrupt() : WalkResult::advance(); + if (auto aliasOp = dyn_cast(op)) { + sourceType = aliasOp.getResult().getType(); + return WalkResult::interrupt(); + } else if (auto exportOp = dyn_cast(op)) { + sourceType = exportOp.getSourceEncoding(); + return WalkResult::interrupt(); + } + return WalkResult::advance(); }) == TraversalResult::INCOMPLETE) { // Analysis failed to find an export op. User needs to rework their program. mlir::emitError(loc) << "unsupported dynamic buffer view export on " << arg; return {}; } - // Extract the type, which must be a static tensor. - auto sourceType = exportOp.getSourceEncoding(); + // The type must be a static tensor for this pass to work. auto tensorType = llvm::dyn_cast(sourceType); if (!tensorType || !tensorType.hasStaticShape()) { mlir::emitError(loc) << "unsupported buffer view export tensor type on " diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/export_benchmark_funcs.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/export_benchmark_funcs.mlir index 03f5b56fc369..44e76a49ad57 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/export_benchmark_funcs.mlir +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/export_benchmark_funcs.mlir @@ -95,8 +95,9 @@ util.func public @importDynamicBufferView(%view: !hal.buffer_view) -> !hal.buffe util.func public @exportBufferViewInPlace(%view: !hal.buffer_view, %storage: !hal.buffer) -> !hal.buffer_view { %0 = hal.tensor.import %view : !hal.buffer_view -> tensor<4xi32> %1 = arith.muli %0, %0 : tensor<4xi32> - %2 = hal.tensor.export %1 into(%storage : !hal.buffer) : tensor<4xi32> -> !hal.buffer_view - util.return %2 : !hal.buffer_view + %2 = hal.tensor.alias %1 : tensor<4xi32> to %storage : !hal.buffer + %3 = hal.tensor.export %2 : tensor<4xi32> -> !hal.buffer_view + util.return %3 : !hal.buffer_view } // CHECK: util.global private @[[GLOBAL_ARG0:.+]] { diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp index 40f63f811516..4a0c36adc322 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp @@ -474,18 +474,9 @@ LogicalResult TensorImportOp::verify() { void TensorExportOp::build(OpBuilder &builder, OperationState &result, Type resultType, Value source, TypeAttr sourceEncoding, StringAttr name) { - build(builder, result, resultType, source, sourceEncoding, - /*targetStorage=*/nullptr, name); -} - -void TensorExportOp::build(OpBuilder &builder, OperationState &result, - Type resultType, Value source, - TypeAttr sourceEncoding, Value targetStorage, - StringAttr name) { auto dynamicDims = IREE::Util::buildDynamicDimsForValue(result.location, source, builder); - build(builder, result, resultType, source, sourceEncoding, dynamicDims, - targetStorage, name); + build(builder, result, resultType, source, sourceEncoding, dynamicDims, name); } Value TensorExportOp::getTiedResult(unsigned resultIndex) { @@ -512,6 +503,33 @@ LogicalResult TensorExportOp::verify() { op.getSource().getType()); } +//===----------------------------------------------------------------------===// +// hal.tensor.alias +//===----------------------------------------------------------------------===// + +Value TensorAliasOp::getTiedResult(unsigned resultIndex) { + return IREE::Util::TiedOpInterface::findTiedBaseValue(getSource()); +} + +::std::optional +TensorAliasOp::getTiedResultOperandIndex(unsigned resultIndex) { + return {0}; // source +} + +SmallVector TensorAliasOp::getTiedResultOperandIndices() { + return {0}; // source +} + +LogicalResult TensorAliasOp::verify() { + TensorAliasOp op = *this; + auto type = llvm::cast(op.getSource().getType()); + if (type.getNumDynamicDims() != op.getSourceDims().size()) { + return op->emitOpError() + << "number of dynamic dims must match the operand type"; + } + return success(); +} + //===----------------------------------------------------------------------===// // hal.tensor.barrier //===----------------------------------------------------------------------===// diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td index 92894fec79df..32551f6ccf3c 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td @@ -167,7 +167,6 @@ def HAL_TensorImportOp : HAL_PureOp<"tensor.import", [ } def HAL_TensorExportOp : HAL_PureOp<"tensor.export", [ - AttrSizedOperandSegments, DeclareOpInterfaceMethods>:$target_storage, OptionalAttr:$name ); let results = (outs @@ -206,7 +199,6 @@ def HAL_TensorExportOp : HAL_PureOp<"tensor.export", [ let assemblyFormat = [{ $source ($name^)? - (`into` `(` $target_storage^ `:` type($target_storage) `)`)? `:` custom($source_encoding, type($source)) (`{` $source_dims^ `}`)? `->` @@ -221,13 +213,6 @@ def HAL_TensorExportOp : HAL_PureOp<"tensor.export", [ "TypeAttr":$sourceEncoding, "StringAttr":$name )>, - OpBuilder<(ins - "Type":$resultType, - "Value":$source, - "TypeAttr":$sourceEncoding, - "Value":$targetStorage, - "StringAttr":$name - )>, ]; let extraClassDeclaration = [{ @@ -240,6 +225,76 @@ def HAL_TensorExportOp : HAL_PureOp<"tensor.export", [ let hasFolder = 1; } +// TODO(#17328): specify an allocation policy to control behavior. +def HAL_TensorAliasOp : HAL_PureOp<"tensor.alias", [ + AllTypesMatch<["source", "result"]>, + AttrSizedOperandSegments, + DeclareOpInterfaceMethods, + Util_ShapeAwareOp, +]> { + let summary = [{hints that tensor storage should alias a HAL buffer view}]; + let description = [{ + Hints that the backing storage of an entire tensor aliases the given storage + buffer. There's no guarantee that the storage will alias and instead only + that the tensor contents will be written to the storage as if a copy had + occurred. This allows the compiler to avoid copies in the ideal case of a + producer that is able to produce directly into the target storage but still + handle cases where the producer is not able to be in-place. + + The storage buffer provided must have sufficient space for the tensor once + encoded. Dynamically shaped tensors may not consume the entire provided + storage. If a buffer view is provided the metadata is ignored and only the + backing buffer is used. + + An optional wait fence can be provided in cases where the storage is not + immediately available. Producers that may alias the storage will wait until + the storage is available before updating the contents. + + Explicit aliasing side-steps any analysis that may be performed by the + compiler and requires users to guarantee that the safety of the aliasing. + Copy-on-write, alias analysis for overlap detection, and ordering via + use-def chains are all ignorant of the aliased buffer memory and only ensure + the compiler consumes or produces the aliased memory consistent with itself. + + Example: + ```mlir + %init = tensor.empty + %value = linalg.generic ... outs(%init) + %aliased = hal.tensor.alias %value : tensor<...> to %buffer : !hal.buffer + ... linalg.generic ins(%aliased) ... + ``` + }]; + + let arguments = (ins + AnyTensor:$source, + HAL_ShapeDynamicDims:$source_dims, + AnyTypeOf<[HAL_Buffer, HAL_BufferView]>:$storage, + Optional:$wait_fence + ); + let results = (outs + AnyTensor:$result + ); + + let assemblyFormat = [{ + (`wait` `(` $wait_fence^ `)` `=` `` `>`)? + $source `:` type($source) (`{` $source_dims^ `}`)? + `to` + $storage `:` type($storage) + attr-dict + }]; + + let extraClassDeclaration = [{ + ValueRange getOperandDynamicDims(unsigned idx) { return getSourceDims(); } + ValueRange getResultDynamicDims(unsigned idx) { return getSourceDims(); } + }]; + + let hasVerifier = 1; +} + def HAL_TensorBarrierOp : HAL_Op<"tensor.barrier", [ AllTypesMatch<["sources", "results"]>, DeclareOpInterfaceMethods, %arg1: index) -> ! // ----- -// CHECK-LABEL: @tensorExportInPlace -util.func public @tensorExportInPlace(%arg0: tensor, %arg1: index, %arg2: !hal.buffer) -> !hal.buffer_view { - // CHECK: hal.tensor.export %arg0 into(%arg2 : !hal.buffer) : tensor as tensor{%arg1} -> !hal.buffer_view - %0 = hal.tensor.export %arg0 into(%arg2 : !hal.buffer) : tensor as tensor{%arg1} -> !hal.buffer_view - util.return %0 : !hal.buffer_view +// CHECK-LABEL: @tensorAlias +util.func public @tensorAlias(%arg0: tensor, %arg1: index, %arg2: !hal.buffer, %arg3: !hal.fence) -> tensor { + // CHECK: hal.tensor.alias wait(%arg3) => %arg0 : tensor{%arg1} to %arg2 : !hal.buffer + %0 = hal.tensor.alias wait(%arg3) => %arg0 : tensor{%arg1} to %arg2 : !hal.buffer + util.return %0 : tensor } // ----- diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/HALToStream/Patterns.cpp b/compiler/src/iree/compiler/Dialect/Stream/Conversion/HALToStream/Patterns.cpp index fc1b7f181930..2acd3ba4f51c 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/HALToStream/Patterns.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/HALToStream/Patterns.cpp @@ -136,53 +136,106 @@ struct ConvertTensorExportOp auto affinityAttr = IREE::Stream::AffinityAttr::lookup(op); auto source = consumeTensorOperand(op.getLoc(), adaptor.getSource(), rewriter); + + // Exporting a produced value - transfer our source value to an externally + // usable resource and directly export it. This will cause an allocation. + auto exportSource = adaptor.getSource(); auto externalType = rewriter.getType( IREE::Stream::Lifetime::External); - auto exportSource = adaptor.getSource(); - auto exportSize = source.resourceSize; - if (adaptor.getTargetStorage()) { - // Query the target storage buffer length; we will only populate up to - // what is required for the output. - auto storageSize = rewriter.createOrFold( - op.getLoc(), rewriter.getIndexType(), - TypeAttr::get(op.getSource().getType()), adaptor.getSourceDims(), - affinityAttr); - - // Import the target storage as a resource that we can use as an update - // target. We overwrite the contents and just cast the storage to the - // target type so we know we can update it. - auto importOp = rewriter.create( - op.getLoc(), externalType, adaptor.getTargetStorage(), - TypeAttr::get(sourceType), adaptor.getSourceDims(), storageSize, - affinityAttr); - - // Copy the source value into the imported target storage. - auto zeroOffset = rewriter.create(op.getLoc(), 0); - auto updateOp = rewriter.create( - op.getLoc(), externalType, importOp.getResult(), - importOp.getResultSize(), zeroOffset, source.resourceSize, - source.resource, source.resourceSize, affinityAttr); - - // Export the updated resource. - // NOTE: the buffer size wrapped in the buffer view is the full size of - // the input buffer. This is so that we don't insert a data dependency on - // sparse operations or data-dependent dynamic shape dimensions. - exportSource = updateOp.getResult(); - exportSize = updateOp.getTargetSize(); - } else { - // Exporting a produced value - transfer our source value to an externally - // usable resource and directly export it. This will cause an allocation. - if (source.resource.getType() != externalType) { - exportSource = rewriter.create( - op.getLoc(), externalType, source.resource, source.resourceSize, - source.resourceSize, affinityAttr, affinityAttr); - } + if (source.resource.getType() != externalType) { + exportSource = rewriter.create( + op.getLoc(), externalType, source.resource, source.resourceSize, + source.resourceSize, affinityAttr, affinityAttr); } // Export (stream resource to buffer view). rewriter.replaceOpWithNewOp( op, targetType, exportSource, TypeAttr::get(sourceType), - adaptor.getSourceDims(), exportSize, affinityAttr); + adaptor.getSourceDims(), source.resourceSize, affinityAttr); + return success(); + } +}; + +// Imports the storage to alias as a resource, copies the source value into it, +// and slices out the source value. This should allow allocation placement to +// elide the update (and subsequently the slice) if possible and otherwise will +// turn into a copy. +// +// Effectively: +// %2 = hal.tensor.alias %0 : tensor<4xf32> to %1 : !hal.buffer_view +// -> +// %storage = stream.tensor.import %1 : !hal.buffer -> tensor<...> +// %update = stream.async.update %0, %storage[...] +// %2 = stream.async.slice %update[...] +struct ConvertTensorAliasOp + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(IREE::HAL::TensorAliasOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto sourceType = op.getSource().getType(); + auto source = + consumeTensorOperand(op.getLoc(), adaptor.getSource(), rewriter); + + // All operations (if any) will happen on the device specified by the alias + // as that indicates the affinity of the storage. + auto affinityAttr = IREE::Stream::AffinityAttr::lookup(op); + + // Query the target storage buffer length; we will only populate up to + // what is required for the output. + auto storageSize = rewriter.createOrFold( + op.getLoc(), rewriter.getIndexType(), + TypeAttr::get(op.getSource().getType()), adaptor.getSourceDims(), + affinityAttr); + + // Import the target storage as a resource that we can use as an update + // target. We overwrite the contents and just cast the storage to the + // target type so we know we can update it. + auto externalType = rewriter.getType( + IREE::Stream::Lifetime::External); + auto importOp = rewriter.create( + op.getLoc(), externalType, adaptor.getStorage(), + TypeAttr::get(sourceType), adaptor.getSourceDims(), storageSize, + affinityAttr); + + // Await the fence, if needed. When not specified the storage is assumed to + // be immediately available. + Value storage = importOp.getResult(); + if (auto waitFence = op.getWaitFence()) { + Value waitTimepoint = rewriter.create( + op.getLoc(), rewriter.getType(), + ValueRange{waitFence}, affinityAttr); + storage = rewriter + .create( + op.getLoc(), ValueRange{storage}, + ValueRange{storageSize}, waitTimepoint) + .getResult(0); + } + + // Copy the source value into the imported target storage. + auto zeroOffset = rewriter.create(op.getLoc(), 0); + auto updateOp = rewriter.create( + op.getLoc(), externalType, storage, storageSize, zeroOffset, + source.resourceSize, source.resource, source.resourceSize, + affinityAttr); + + // Slice out the value from the updated tensor. + // This preserves the use-def chain but is almost always elided by aliasing + // the input value later on. + auto sliceOp = rewriter.create( + op.getLoc(), externalType, updateOp.getResult(), + updateOp.getTargetSize(), zeroOffset, source.resourceSize, + source.resourceSize, affinityAttr); + + // Transfer to match original lifetime (if needed). + Value result = sliceOp.getResult(); + if (source.resource.getType() != result.getType()) { + result = rewriter.create( + op.getLoc(), source.resource.getType(), result, source.resourceSize, + source.resourceSize, affinityAttr, affinityAttr); + } + rewriter.replaceOp(op, result); + return success(); } }; @@ -233,6 +286,7 @@ void populateHALToStreamConversionPatterns(MLIRContext *context, [](IREE::HAL::BufferViewType type) { return type; }); patterns.insert(typeConverter, context); patterns.insert(typeConverter, context); + patterns.insert(typeConverter, context); patterns.insert(typeConverter, context); } diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/HALToStream/test/abi_ops.mlir b/compiler/src/iree/compiler/Dialect/Stream/Conversion/HALToStream/test/abi_ops.mlir index 8dc67055407c..c96596f28389 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/HALToStream/test/abi_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/HALToStream/test/abi_ops.mlir @@ -69,32 +69,47 @@ util.func public @exportBufferView(%tensor: tensor, %dim0: index, %di // ----- -// CHECK-LABEL: @exportBufferViewInPlace -// CHECK-SAME: (%[[TENSOR:.+]]: !stream.resource<*>, %[[SIZE:.+]]: index, %[[DIM0:.+]]: index, %[[DIM1:.+]]: index, %[[STORAGE:.+]]: !hal.buffer) -util.func public @exportBufferViewInPlace(%tensor: tensor, %dim0: index, %dim1: index, %storage: !hal.buffer) -> !hal.buffer_view { - // CHECK: %[[STORAGE_SIZE:.+]] = stream.tensor.sizeof tensor{%[[DIM0]], %[[DIM1]]} : index - // CHECK-NEXT: %[[STORAGE_IMPORT:.+]] = stream.tensor.import %[[STORAGE]] - // CHECK-SAME: : !hal.buffer -> tensor{%[[DIM0]], %[[DIM1]]} in !stream.resource{%[[STORAGE_SIZE]]} - // CHECK-NEXT: %[[STORAGE_UPDATE:.+]] = stream.async.update %[[TENSOR]], %[[STORAGE_IMPORT]][%c0 to %[[SIZE]]] - // CHECK-SAME: : !stream.resource<*>{%[[SIZE]]} -> %[[STORAGE_IMPORT]] as !stream.resource{%[[STORAGE_SIZE]]} - // CHECK-NEXT: %[[STORAGE_RESULT:.+]] = stream.tensor.export %[[STORAGE_UPDATE]] : - // CHECK-SAME: tensor{%[[DIM0]], %[[DIM1]]} in !stream.resource{%[[STORAGE_SIZE]]} - // CHECK-SAME: -> !hal.buffer_view - %0 = hal.tensor.export %tensor into(%storage : !hal.buffer) : tensor{%dim0, %dim1} -> !hal.buffer_view - // CHECK: util.return %[[STORAGE_RESULT]] - util.return %0 : !hal.buffer_view +// CHECK-LABEL: @aliasStorage +// CHECK-SAME: (%[[TENSOR:.+]]: !stream.resource<*>, %[[SIZE:.+]]: index, %[[DIM0:.+]]: index, %[[STORAGE:.+]]: !hal.buffer) +util.func public @aliasStorage(%tensor: tensor, %dim0: index, %storage: !hal.buffer) -> tensor { + // CHECK: %[[MIN_STORAGE_SIZE:.+]] = stream.tensor.sizeof tensor{%[[DIM0]]} + // CHECK: %[[STORAGE_RESOURCE:.+]] = stream.tensor.import %[[STORAGE]] : !hal.buffer -> tensor{%[[DIM0]]} in !stream.resource{%[[MIN_STORAGE_SIZE]]} + // CHECK: %[[UPDATE:.+]] = stream.async.update %[[TENSOR]], %[[STORAGE_RESOURCE]][%c0 to %[[SIZE]]] : !stream.resource<*>{%[[SIZE]]} -> %[[STORAGE_RESOURCE]] as !stream.resource{%[[MIN_STORAGE_SIZE]]} + // CHECK: %[[SLICE:.+]] = stream.async.slice %[[UPDATE]][%c0 to %[[SIZE]]] : !stream.resource{%[[MIN_STORAGE_SIZE]]} -> !stream.resource{%[[SIZE]]} + // CHECK: %[[RESULT:.+]] = stream.async.transfer %[[SLICE]] : !stream.resource{%[[SIZE]]} -> !stream.resource<*>{%[[SIZE]]} + %0 = hal.tensor.alias %tensor : tensor{%dim0} to %storage : !hal.buffer + // CHECK: util.return %[[RESULT]] + util.return %0 : tensor } // ----- -// As with @exportBufferViewInPlace above but using !hal.buffer_view storage. +// CHECK-LABEL: @aliasStorageAsync +// CHECK-SAME: (%[[TENSOR:.+]]: !stream.resource<*>, %[[SIZE:.+]]: index, %[[DIM0:.+]]: index, %[[STORAGE:.+]]: !hal.buffer, %[[FENCE:.+]]: !hal.fence) +util.func public @aliasStorageAsync(%tensor: tensor, %dim0: index, %storage: !hal.buffer, %fence: !hal.fence) -> tensor { + // CHECK-DAG: %[[MIN_STORAGE_SIZE:.+]] = stream.tensor.sizeof tensor{%[[DIM0]]} + // CHECK-DAG: %[[UNREADY_STORAGE:.+]] = stream.tensor.import %[[STORAGE]] : !hal.buffer -> tensor{%[[DIM0]]} in !stream.resource{%[[MIN_STORAGE_SIZE]]} + // CHECK-DAG: %[[TIMEPOINT:.+]] = stream.timepoint.import %[[FENCE]] + // CHECK-DAG: %[[READY_STORAGE:.+]] = stream.timepoint.await %[[TIMEPOINT]] => %[[UNREADY_STORAGE]] : !stream.resource{%[[MIN_STORAGE_SIZE]]} + // CHECK: %[[UPDATE:.+]] = stream.async.update %[[TENSOR]], %[[READY_STORAGE]][%c0 to %[[SIZE]]] : !stream.resource<*>{%[[SIZE]]} -> %[[READY_STORAGE]] as !stream.resource{%[[MIN_STORAGE_SIZE]]} + // CHECK: %[[SLICE:.+]] = stream.async.slice %[[UPDATE]][%c0 to %[[SIZE]]] : !stream.resource{%[[MIN_STORAGE_SIZE]]} -> !stream.resource{%[[SIZE]]} + // CHECK: %[[RESULT:.+]] = stream.async.transfer %[[SLICE]] : !stream.resource{%[[SIZE]]} -> !stream.resource<*>{%[[SIZE]]} + %0 = hal.tensor.alias wait(%fence) => %tensor : tensor{%dim0} to %storage : !hal.buffer + // CHECK: util.return %[[RESULT]] + util.return %0 : tensor +} -// CHECK-LABEL: @exportBufferViewInPlaceToView -// CHECK-SAME: (%[[TENSOR:.+]]: !stream.resource<*>, %[[SIZE:.+]]: index, %[[DIM0:.+]]: index, %[[DIM1:.+]]: index, %[[STORAGE:.+]]: !hal.buffer_view) -util.func public @exportBufferViewInPlaceToView(%tensor: tensor, %dim0: index, %dim1: index, %storage: !hal.buffer_view) -> !hal.buffer_view { - // CHECK: %[[STORAGE_SIZE:.+]] = stream.tensor.sizeof tensor{%[[DIM0]], %[[DIM1]]} : index - // CHECK-NEXT: %[[STORAGE_IMPORT:.+]] = stream.tensor.import %[[STORAGE]] - // CHECK-SAME: : !hal.buffer_view -> tensor{%[[DIM0]], %[[DIM1]]} in !stream.resource{%[[STORAGE_SIZE]]} - %0 = hal.tensor.export %tensor into(%storage : !hal.buffer_view) : tensor{%dim0, %dim1} -> !hal.buffer_view - util.return %0 : !hal.buffer_view +// ----- + +// CHECK-LABEL: @tensorBarrier +// CHECK-SAME: (%[[TENSOR0:.+]]: !stream.resource<*>, %[[SIZE0:.+]]: index, %[[TENSOR1:.+]]: !stream.resource<*>, %[[SIZE1:.+]]: index, %[[FENCE:.+]]: !hal.fence) +util.func public @tensorBarrier(%tensor0: tensor<3xf32>, %tensor1: tensor, %fence: !hal.fence) -> (tensor<3xf32>, tensor) { + // CHECK-DAG: %[[TENSOR0_AFTER:.+]], %[[TENSOR0_BARRIER:.+]] = stream.timepoint.barrier %[[TENSOR0]] : !stream.resource<*>{%[[SIZE0]]} => !stream.timepoint + // CHECK-DAG: %[[TENSOR1_AFTER:.+]], %[[TENSOR1_BARRIER:.+]] = stream.timepoint.barrier %[[TENSOR1]] : !stream.resource<*>{%[[SIZE1]]} => !stream.timepoint + // CHECK-NEXT: %[[JOIN:.+]] = stream.timepoint.join max(%[[TENSOR0_BARRIER]], %[[TENSOR1_BARRIER]]) => !stream.timepoint + // CHECK-NEXT: stream.timepoint.chain_external %[[JOIN]] => (%[[FENCE]] : !hal.fence) + %0:2 = hal.tensor.barrier join(%tensor0, %tensor1 : tensor<3xf32>, tensor) => %fence : !hal.fence + // CHECK: util.return %[[TENSOR0_AFTER]], %[[SIZE0]], %[[TENSOR1_AFTER]], %[[SIZE1]] + util.return %0#0, %0#1 : tensor<3xf32>, tensor } + diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp index 3b0df17950d3..3982c4e2d49e 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp @@ -1745,6 +1745,95 @@ OpFoldResult AsyncUpdateOp::fold(FoldAdaptor operands) { namespace { +// Detects updates that are overwriting entire tensors that could be folded into +// an in-place producer operation. +// +// Example: +// %1 = stream.async.dispatch ... %0 -> %0 +// %2 = stream.async.update %1, %0[full] +// -> +// %2 = stream.async.dispatch .... %0 -> %0 +struct ElideInPlaceUpdate : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AsyncUpdateOp updateOp, + PatternRewriter &rewriter) const override { + // Look for entire tensor replacement. + const bool isOverwrite = + updateOp.getUpdateSize() == updateOp.getTargetSize() && + updateOp.getUpdate().getType() == updateOp.getType(); + if (!isOverwrite) { + // Ignore partial updates. + return failure(); + } + + // Detect producers that are performing their operation in-place. + // This should be a global analysis to detect in-place operations across + // control flow/calls + auto updateOperand = + IREE::Util::TiedOpInterface::findTiedBaseValue(updateOp.getUpdate()); + auto targetOperand = + IREE::Util::TiedOpInterface::findTiedBaseValue(updateOp.getTarget()); + + // Condition: if overwriting the entire tied operand then this is a no-op. + if (updateOperand != targetOperand) { + return rewriter.notifyMatchFailure( + updateOp, + "update overwrite target is not tied to the producer operand"); + } + + // If there are multiple users we may need them for copies. We need to + // ensure that any uses of the produced update are reads - if not + // copy-on-write will require the update op to exist in order to identify + // the condition. + SmallVector accessRanges; + for (auto &use : updateOp.getUpdate().getUses()) { + // Ops that are async resource access aware let us see if the particular + // source update resource is written. Any tied ops will have their uses + // marked as writes so that we don't need to walk down all transitive + // users to detect writes. + if (auto accessOp = + dyn_cast(use.getOwner())) { + accessOp.getAsyncAccessRanges(accessRanges); + for (auto &accessRange : accessRanges) { + if (accessRange.resource == updateOp.getUpdate() && + !accessRange.isReadOnly()) { + // TODO(benvanik): allow non-overlapping writes by checking + // accessRange.mayOverlap - may not be worth it due to the update + // source being the entire resource. If we did this on copies as + // well we'd want that. + return rewriter.notifyMatchFailure( + updateOp, "usage writes the update resource and conservatively " + "blocks elision"); + } + } + accessRanges.clear(); + continue; + } + + // Use memory effect analysis for ops in other dialects that don't + // indicate their access ranges. We conservatively fail if the op doesn't + // declare memory effects. + std::optional> effects = + getEffectsRecursively(use.getOwner()); + if (!effects) { + // Effect analysis failed. + return rewriter.notifyMatchFailure( + updateOp, "usage has unknown memory effects and blocks elision"); + } + for (const MemoryEffects::EffectInstance &effect : *effects) { + if (isa(effect.getEffect())) { + // Write effect indicates something we can't analyze (today). + return rewriter.notifyMatchFailure( + updateOp, "usage has side-effects and blocks elision"); + } + } + } + + rewriter.replaceOp(updateOp, updateOp.getUpdate()); + return success(); + } +}; + // Turns a splat+update-from into a fill. // // Example: @@ -1816,6 +1905,7 @@ void AsyncUpdateOp::getCanonicalizationPatterns(RewritePatternSet &results, // affinity/lifetime differ. // TODO(#6972): updates into splats could become alloca + fill exclusive // region + update into undefined contents (used in padding). + results.insert(context); results.insert(context); results.insert(context); results.insert>(context); diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/test/async_folding.mlir b/compiler/src/iree/compiler/Dialect/Stream/IR/test/async_folding.mlir index 64ae8f3a2591..155a857205ca 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/IR/test/async_folding.mlir +++ b/compiler/src/iree/compiler/Dialect/Stream/IR/test/async_folding.mlir @@ -288,6 +288,68 @@ util.func private @FoldLocalAsyncUpdateOp(%arg0: !stream.resource<*>, %arg1: ind // ----- +// CHECK-LABEL: @ElideInPlaceUpdateUpdate +util.func private @ElideInPlaceUpdateUpdate(%arg0: !stream.resource<*>, %arg1: index, %arg2: !stream.resource<*>, %arg3: index) -> !stream.resource<*> { + %c0 = arith.constant 0 : index + // CHECK: %[[RESULT:.+]] = stream.async.update %arg0, %arg2[%c0 to %arg1] : !stream.resource<*>{%arg1} -> %arg2 as !stream.resource<*>{%arg3} + %0 = stream.async.update %arg0, %arg2[%c0 to %arg1] : !stream.resource<*>{%arg1} -> %arg2 as !stream.resource<*>{%arg3} + // CHECK-NOT: stream.async.update + %1 = stream.async.update %0, %arg2[%c0 to %arg3] : !stream.resource<*>{%arg3} -> %arg2 as !stream.resource<*>{%arg3} + // CHECK: util.return %[[RESULT]] + util.return %1 : !stream.resource<*> +} + +// ----- + +// CHECK-LABEL: @ElideInPlaceUpdateDispatch +util.func private @ElideInPlaceUpdateDispatch(%arg0: !stream.resource<*>, %arg1: index) -> !stream.resource<*> { + %c0 = arith.constant 0 : index + // CHECK: %[[RESULT:.+]] = stream.async.dispatch @ex::@fn(%arg0[%c0 to %arg1 for %arg1]) : (!stream.resource<*>{%arg1}) -> %arg0{%arg1} + %0 = stream.async.dispatch @ex::@fn(%arg0[%c0 to %arg1 for %arg1]) : (!stream.resource<*>{%arg1}) -> %arg0{%arg1} + // CHECK-NOT: stream.async.update + %1 = stream.async.update %0, %arg0[%c0 to %arg1] : !stream.resource<*>{%arg1} -> %arg0 as !stream.resource<*>{%arg1} + // CHECK: util.return %[[RESULT]] + util.return %1 : !stream.resource<*> +} + +// ----- + +// Tests that multiple users of the produced value will still allow the update +// to be elided so long as they are reads. + +// CHECK-LABEL: @ElideInPlaceUpdateDispatchMultiUse +util.func private @ElideInPlaceUpdateDispatchMultiUse(%arg0: !stream.resource<*>, %arg1: index) -> (!stream.resource<*>, !stream.resource<*>) { + %c0 = arith.constant 0 : index + // CHECK: %[[RESULT0:.+]] = stream.async.dispatch @ex::@fn0(%arg0[%c0 to %arg1 for %arg1]) : (!stream.resource<*>{%arg1}) -> %arg0{%arg1} + %0 = stream.async.dispatch @ex::@fn0(%arg0[%c0 to %arg1 for %arg1]) : (!stream.resource<*>{%arg1}) -> %arg0{%arg1} + // CHECK-NOT: stream.async.update + %1 = stream.async.update %0, %arg0[%c0 to %arg1] : !stream.resource<*>{%arg1} -> %arg0 as !stream.resource<*>{%arg1} + // CHECK: %[[RESULT1:.+]] = stream.async.dispatch @ex::@fn1(%[[RESULT0]][%c0 to %arg1 for %arg1]) : (!stream.resource<*>{%arg1}) -> !stream.resource<*>{%arg1} + %2 = stream.async.dispatch @ex::@fn1(%0[%c0 to %arg1 for %arg1]) : (!stream.resource<*>{%arg1}) -> !stream.resource<*>{%arg1} + // CHECK: util.return %[[RESULT0]], %[[RESULT1]] + util.return %1, %2 : !stream.resource<*>, !stream.resource<*> +} + +// ----- + +// Tests that writes on the update source will fail to elide the update. +// TODO(benvanik): support looking for writes only prior to the update that are +// known-safe. + +// CHECK-LABEL: @ElideInPlaceUpdateDispatchMultiUseWrite +util.func private @ElideInPlaceUpdateDispatchMultiUseWrite(%arg0: !stream.resource<*>, %arg1: index) -> (!stream.resource<*>, !stream.resource<*>) { + %c0 = arith.constant 0 : index + // CHECK: stream.async.dispatch @ex::@fn0 + %0 = stream.async.dispatch @ex::@fn0(%arg0[%c0 to %arg1 for %arg1]) : (!stream.resource<*>{%arg1}) -> %arg0{%arg1} + // CHECK: stream.async.update + %1 = stream.async.update %0, %arg0[%c0 to %arg1] : !stream.resource<*>{%arg1} -> %arg0 as !stream.resource<*>{%arg1} + // CHECK: stream.async.dispatch @ex::@fn1 + %2 = stream.async.dispatch @ex::@fn1(%0[%c0 to %arg1 for %arg1]) : (!stream.resource<*>{%arg1}) -> %0{%arg1} + util.return %1, %2 : !stream.resource<*>, !stream.resource<*> +} + +// ----- + // CHECK-LABEL: @CombineSplatUpdateFromToFill util.func private @CombineSplatUpdateFromToFill(%arg0: !stream.resource<*>, %arg1: index) -> !stream.resource<*> { %c0 = arith.constant 0 : index diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilTypes.cpp b/compiler/src/iree/compiler/Dialect/Util/IR/UtilTypes.cpp index 3d0d28e9d6e2..fcdb15f91635 100644 --- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilTypes.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilTypes.cpp @@ -206,11 +206,15 @@ bool tryMoveProducerBefore(Value value, Operation *consumerOp) { // satisfies the request based on SSA dominance. return true; } + + // Recursively try to move each operand. + // TODO(benvanik): change to a worklist to avoid potential stack explosion. for (auto operand : producerOp->getOperands()) { - if (!isValueUsableForOp(operand, consumerOp)) { + if (!tryMoveProducerBefore(operand, consumerOp)) { return false; } } + producerOp->moveBefore(consumerOp); return true; } diff --git a/compiler/src/iree/compiler/InputConversion/Common/IREEImportPublic.cpp b/compiler/src/iree/compiler/InputConversion/Common/IREEImportPublic.cpp index dececb2c6db0..1c3664803ff7 100644 --- a/compiler/src/iree/compiler/InputConversion/Common/IREEImportPublic.cpp +++ b/compiler/src/iree/compiler/InputConversion/Common/IREEImportPublic.cpp @@ -244,7 +244,7 @@ class TensorExportPattern rewriter.replaceOpWithNewOp( srcOp, resultType, adaptor.getSource(), TypeAttr::get(adaptor.getSource().getType()), adaptor.getSourceDims(), - /*target_storage=*/nullptr, /*name=*/nullptr); + /*name=*/nullptr); } return success(); }