From 5f7b471f88e9c8e145637bc1d2f7a4c1d8b4462e Mon Sep 17 00:00:00 2001 From: Han-Chung Wang Date: Thu, 6 Feb 2025 09:27:50 -0800 Subject: [PATCH] [Stream] Add layouts to encodings for all stream tensor AffinityOp. (#19726) The revision adds the support for the rest of AffinityOp that have TensorPhase trait, i.e., TensorCloneOp, TensorSliceOp, TensorFillOp, and TensorUpdateOp ops. It is tricky to handle encodings for transfer ops, so only the encoding in the fill op is updated. If other operations have tensor encodings, it returns a failure for now. There are two stream tensor ops do not implement the AffinityOpInterface, so they are not supported within the revision. They are stream.tensor.load op and stream.tensor.store op. We should be able to track the resource affinity for these two ops, and it requires additional analysis. Thus, they are not scoped within the revision. The revision also adds the missing documentation to the `addLayoutsToTensorPhaseOps` method. --------- Signed-off-by: hanhanW --- .../Stream/Transforms/SpecializeEncodings.cpp | 101 +++++++++++++++++- .../Transforms/test/specialize_encodings.mlir | 99 ++++++++++++++++- 2 files changed, 197 insertions(+), 3 deletions(-) diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/SpecializeEncodings.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/SpecializeEncodings.cpp index 1fbaedbbd778..92310ab34721 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/SpecializeEncodings.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/SpecializeEncodings.cpp @@ -358,6 +358,22 @@ updateTensorSizeOfOp(RewriterBase &rewriter, return success(); } +/// Updates the target encoding of `op` with resolved layouts. +static LogicalResult +updateTensorFillOp(RewriterBase &rewriter, IREE::Stream::TensorFillOp op, + const SetVector &layoutResolvers) { + auto encodingType = dyn_cast(op.getTargetEncoding()); + std::optional encodingAttr = + getEncodingWithNewLayouts(encodingType, layoutResolvers); + if (!encodingAttr) { + return success(); + } + rewriter.modifyOpInPlace(op, [&] { + op.setTargetEncoding(cloneWithEncoding(encodingType, encodingAttr.value())); + }); + return success(); +} + /// Returns failure if `op` has encoding. The EncodingAttr has padding /// semantic, a constant op with such encoding can not be resolved at this /// moment. @@ -375,7 +391,70 @@ updateTensorConstantOp(RewriterBase &rewriter, return success(); } -/// Updates the result_encoding for `op`. The op have to define a +/// Returns a failure if there are encodings in target encoding type or update +/// encoding type. +static LogicalResult updateTensorUpdateOp(RewriterBase &rewriter, + IREE::Stream::TensorUpdateOp op) { + auto targetEncodingType = dyn_cast(op.getTargetEncoding()); + if (targetEncodingType && targetEncodingType.getEncoding()) { + return failure(); + } + auto updateEncodingType = dyn_cast(op.getUpdateEncoding()); + if (updateEncodingType && updateEncodingType.getEncoding()) { + return failure(); + } + return success(); +} + +/// Returns a failure if there are encodings in source encoding type or result +/// encoding type. +static LogicalResult updateTensorCloneOp(RewriterBase &rewriter, + IREE::Stream::TensorCloneOp op) { + auto sourceEncodingType = dyn_cast(op.getSourceEncoding()); + if (sourceEncodingType && sourceEncodingType.getEncoding()) { + return failure(); + } + auto resultEncodingType = dyn_cast(op.getResultEncoding()); + if (resultEncodingType && resultEncodingType.getEncoding()) { + return failure(); + } + return success(); +} + +/// Returns a failure if there are encodings in source encoding type or result +/// encoding type. +static LogicalResult updateTensorSliceOp(RewriterBase &rewriter, + IREE::Stream::TensorSliceOp op) { + auto sourceEncodingType = dyn_cast(op.getSourceEncoding()); + if (sourceEncodingType && sourceEncodingType.getEncoding()) { + return failure(); + } + auto resultEncodingType = dyn_cast(op.getResultEncoding()); + if (resultEncodingType && resultEncodingType.getEncoding()) { + return failure(); + } + return success(); +} + +/// Updates the source_encoding for `op`. The op has to define a +/// `source_encoding` parameter. +template +static LogicalResult +updateSourceEncoding(RewriterBase &rewriter, OpTy op, + const SetVector &layoutResolvers) { + auto encodingType = dyn_cast(op.getSourceEncoding()); + std::optional encodingAttr = + getEncodingWithNewLayouts(encodingType, layoutResolvers); + if (!encodingAttr) { + return success(); + } + rewriter.modifyOpInPlace(op, [&] { + op.setSourceEncoding(cloneWithEncoding(encodingType, encodingAttr.value())); + }); + return success(); +} + +/// Updates the result_encoding for `op`. The op has to define a /// `result_encoding` parameter. template static LogicalResult @@ -393,6 +472,16 @@ updateResultEncoding(RewriterBase &rewriter, OpTy op, return success(); } +/// Adds the resolved layouts to all tensor types on stream tensor ops, if +/// encodings are present. Most of stream tensor ops implement +/// AffinityOpInterface, where a stream affinity indicates the kind of +/// enviroment the ops are expected run in. When an encoding is present in the +/// tensor type, the method resolves the layouts, strips outdated information, +/// and adds the resolved layouts to the encodings. The updated encodings should +/// have enough information for other lowering transformations. +/// TODO(hanchung): Add support for stream.tensor.load ops and +/// stream.tensor.store ops. They are not affinity ops, so additional analysis +/// will be needed in the work. static LogicalResult addLayoutsToTensorPhaseOps( ModuleOp moduleOp, IREE::Stream::AffinityAnalysis &affinityAnalysis, FunctionOpInterface funcOp, @@ -424,7 +513,6 @@ static LogicalResult addLayoutsToTensorPhaseOps( return affinityOp.emitError("failed on making layout resolvers"); } - // TODO(hanchung): Update other Stream operations. LogicalResult result = TypeSwitch(affinityOp) .Case([&](auto op) { @@ -442,6 +530,15 @@ static LogicalResult addLayoutsToTensorPhaseOps( .Case([&](auto op) { return updateTensorConstantOp(rewriter, op, layoutResolvers); }) + .Case([&](auto op) { + return updateTensorFillOp(rewriter, op, layoutResolvers); + }) + .Case( + [&](auto op) { return updateTensorCloneOp(rewriter, op); }) + .Case( + [&](auto op) { return updateTensorSliceOp(rewriter, op); }) + .Case( + [&](auto op) { return updateTensorUpdateOp(rewriter, op); }) .Default([](Operation *op) { return op->emitOpError("Unhandled stream op"); }); diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/specialize_encodings.mlir b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/specialize_encodings.mlir index 2c71a86e1639..f57664bcce95 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/specialize_encodings.mlir +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/specialize_encodings.mlir @@ -65,12 +65,39 @@ module { // ----- +#map0 = affine_map<(m, n, k) -> (m, k)> +#map1 = affine_map<(m, n, k) -> (k, n)> +#map2 = affine_map<(m, n, k) -> (m, n)> +#executable_target_vmvx_bytecode_fb = #hal.executable.target<"vmvx", "vmvx-bytecode-fb", {encoding = #iree_encoding.unspecialized_encoding<123>}> +#device_target_local_0_ = #hal.device.target<"local", {ordinal = 0 : index}, [#executable_target_vmvx_bytecode_fb]> : !hal.device +#encoding = #iree_encoding.encoding +module { + util.global private @device_a = #device_target_local_0_ + + util.func public @tensor_fill_op(%arg0: f32, %arg1: !stream.resource<*>, %arg2: index, %arg3: index) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %0 = stream.tensor.fill on(#hal.device.affinity<@device_a>) + %arg0, %arg1[%c0, %c0 for %c1, %c1] : f32 + -> tensor{%arg2} in %arg1 as !stream.resource<*>{%arg3} + util.return + } +} +// CHECK-DAG: #[[$ENCODING:.+]] = #iree_encoding.encoding<{{.+}} layouts = [#iree_encoding.specialized_encoding<123, tensor>] +// CHECK: #[[TARGET:.+]] = #hal.device.target +// CHECK: util.global private @[[$DEVICE:.+]] = #[[TARGET]] +// CHECK-LABEL: util.func public @tensor_fill_op +// CHECK: stream.tensor.fill on(#hal.device.affinity<@[[$DEVICE]]>) +// CHECK-SAME: f32 -> tensor + +// ----- + // Checks that the stream.tensor.constant op with encoding is not supported. #map0 = affine_map<(m, n, k) -> (m, k)> #map1 = affine_map<(m, n, k) -> (k, n)> #map2 = affine_map<(m, n, k) -> (m, n)> -#executable_target_vmvx_bytecode_fb = #hal.executable.target<"vmvx", "vmvx-bytecode-fb", {encoding = #iree_cpu.vmvx_encoding_layout<>}> +#executable_target_vmvx_bytecode_fb = #hal.executable.target<"vmvx", "vmvx-bytecode-fb", {encoding = #iree_encoding.unspecialized_encoding<123>}> #device_target_local_0_ = #hal.device.target<"local", {ordinal = 0 : index}, [#executable_target_vmvx_bytecode_fb]> : !hal.device #encoding = #iree_encoding.encoding module { @@ -85,6 +112,76 @@ module { // ----- +// Checks that the stream.tensor.clone op with encoding is not supported. + +#map0 = affine_map<(m, n, k) -> (m, k)> +#map1 = affine_map<(m, n, k) -> (k, n)> +#map2 = affine_map<(m, n, k) -> (m, n)> +#executable_target_vmvx_bytecode_fb = #hal.executable.target<"vmvx", "vmvx-bytecode-fb", {encoding = #iree_encoding.unspecialized_encoding<123>}> +#device_target_local_0_ = #hal.device.target<"local", {ordinal = 0 : index}, [#executable_target_vmvx_bytecode_fb]> : !hal.device +#encoding = #iree_encoding.encoding +module { + util.global private @device_a = #device_target_local_0_ + + // expected-error @+1 {{failed on adding layouts to Stream::TensorPhaseOp with encodings}} + util.func public @tensor_clone_op(%arg0: !stream.resource<*>, %arg1: index, %arg2: index, %arg3: index, %arg4: index) { + %0 = stream.tensor.clone on(#hal.device.affinity<@device_a>) + %arg0 : tensor{%arg1} in !stream.resource<*>{%arg2} + -> tensor{%arg1} in !stream.resource<*>{%arg2} + util.return + } +} + +// ----- + +// Checks that the stream.tensor.slice op with encoding is not supported. + +#map0 = affine_map<(m, n, k) -> (m, k)> +#map1 = affine_map<(m, n, k) -> (k, n)> +#map2 = affine_map<(m, n, k) -> (m, n)> +#executable_target_vmvx_bytecode_fb = #hal.executable.target<"vmvx", "vmvx-bytecode-fb", {encoding = #iree_encoding.unspecialized_encoding<123>}> +#device_target_local_0_ = #hal.device.target<"local", {ordinal = 0 : index}, [#executable_target_vmvx_bytecode_fb]> : !hal.device +#encoding = #iree_encoding.encoding +module { + util.global private @device_a = #device_target_local_0_ + + // expected-error @+1 {{failed on adding layouts to Stream::TensorPhaseOp with encodings}} + util.func public @tensor_slice_op_with_encoding(%arg0: !stream.resource<*>, %arg1: index, %arg2: index, %arg3: index, %arg4: index) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %1 = stream.tensor.slice on(#hal.device.affinity<@device_a>) + %arg0[%c0, %c1 for %arg3, %c1] : tensor{%arg1} in !stream.resource<*>{%arg2} + -> tensor{%arg3} in !stream.resource<*>{%arg4} + util.return + } +} + +// ----- + +// Checks that the stream.tensor.update op with encoding is not supported. + +#map0 = affine_map<(m, n, k) -> (m, k)> +#map1 = affine_map<(m, n, k) -> (k, n)> +#map2 = affine_map<(m, n, k) -> (m, n)> +#executable_target_vmvx_bytecode_fb = #hal.executable.target<"vmvx", "vmvx-bytecode-fb", {encoding = #iree_encoding.unspecialized_encoding<123>}> +#device_target_local_0_ = #hal.device.target<"local", {ordinal = 0 : index}, [#executable_target_vmvx_bytecode_fb]> : !hal.device +#encoding = #iree_encoding.encoding +module { + util.global private @device_a = #device_target_local_0_ + + // expected-error @+1 {{failed on adding layouts to Stream::TensorPhaseOp with encodings}} + util.func public @tensor_update_op(%arg0: !stream.resource<*>, %arg1: index, %arg2: !stream.resource<*>, %arg3: index, %arg4: index) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %0 = stream.tensor.update on(#hal.device.affinity<@device_a>) + %arg0, %arg2[%c0, %c0] : tensor<2x2xf32, #encoding> in !stream.resource<*>{%arg1} + -> tensor{%arg3} in %arg2 as !stream.resource<*>{%arg4} + util.return + } +} + +// ----- + #executable_target_vmvx_bytecode_fb = #hal.executable.target<"vmvx", "vmvx-bytecode-fb", {encoding = #iree_encoding.unspecialized_encoding<123>}> #map = affine_map<(d0) -> (d0)> #map0 = affine_map<(m, n, k) -> (m, k)>