diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/SpecializeEncodings.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/SpecializeEncodings.cpp
index 1fbaedbbd778..92310ab34721 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/SpecializeEncodings.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/SpecializeEncodings.cpp
@@ -358,6 +358,22 @@ updateTensorSizeOfOp(RewriterBase &rewriter,
   return success();
 }
 
+/// Updates the target encoding of `op` with resolved layouts.
+static LogicalResult
+updateTensorFillOp(RewriterBase &rewriter, IREE::Stream::TensorFillOp op,
+                   const SetVector<Attribute> &layoutResolvers) {
+  auto encodingType = dyn_cast<RankedTensorType>(op.getTargetEncoding());
+  std::optional<IREE::Encoding::EncodingAttr> encodingAttr =
+      getEncodingWithNewLayouts(encodingType, layoutResolvers);
+  if (!encodingAttr) {
+    return success();
+  }
+  rewriter.modifyOpInPlace(op, [&] {
+    op.setTargetEncoding(cloneWithEncoding(encodingType, encodingAttr.value()));
+  });
+  return success();
+}
+
 /// Returns failure if `op` has encoding. The EncodingAttr has padding
 /// semantic, a constant op with such encoding can not be resolved at this
 /// moment.
@@ -375,7 +391,70 @@ updateTensorConstantOp(RewriterBase &rewriter,
   return success();
 }
 
-/// Updates the result_encoding for `op`. The op have to define a
+/// Returns a failure if there are encodings in target encoding type or update
+/// encoding type.
+static LogicalResult updateTensorUpdateOp(RewriterBase &rewriter,
+                                          IREE::Stream::TensorUpdateOp op) {
+  auto targetEncodingType = dyn_cast<RankedTensorType>(op.getTargetEncoding());
+  if (targetEncodingType && targetEncodingType.getEncoding()) {
+    return failure();
+  }
+  auto updateEncodingType = dyn_cast<RankedTensorType>(op.getUpdateEncoding());
+  if (updateEncodingType && updateEncodingType.getEncoding()) {
+    return failure();
+  }
+  return success();
+}
+
+/// Returns a failure if there are encodings in source encoding type or result
+/// encoding type.
+static LogicalResult updateTensorCloneOp(RewriterBase &rewriter,
+                                         IREE::Stream::TensorCloneOp op) {
+  auto sourceEncodingType = dyn_cast<RankedTensorType>(op.getSourceEncoding());
+  if (sourceEncodingType && sourceEncodingType.getEncoding()) {
+    return failure();
+  }
+  auto resultEncodingType = dyn_cast<RankedTensorType>(op.getResultEncoding());
+  if (resultEncodingType && resultEncodingType.getEncoding()) {
+    return failure();
+  }
+  return success();
+}
+
+/// Returns a failure if there are encodings in source encoding type or result
+/// encoding type.
+static LogicalResult updateTensorSliceOp(RewriterBase &rewriter,
+                                         IREE::Stream::TensorSliceOp op) {
+  auto sourceEncodingType = dyn_cast<RankedTensorType>(op.getSourceEncoding());
+  if (sourceEncodingType && sourceEncodingType.getEncoding()) {
+    return failure();
+  }
+  auto resultEncodingType = dyn_cast<RankedTensorType>(op.getResultEncoding());
+  if (resultEncodingType && resultEncodingType.getEncoding()) {
+    return failure();
+  }
+  return success();
+}
+
+/// Updates the source_encoding for `op`. The op has to define a
+/// `source_encoding` parameter.
+template <typename OpTy>
+static LogicalResult
+updateSourceEncoding(RewriterBase &rewriter, OpTy op,
+                     const SetVector<Attribute> &layoutResolvers) {
+  auto encodingType = dyn_cast<RankedTensorType>(op.getSourceEncoding());
+  std::optional<IREE::Encoding::EncodingAttr> encodingAttr =
+      getEncodingWithNewLayouts(encodingType, layoutResolvers);
+  if (!encodingAttr) {
+    return success();
+  }
+  rewriter.modifyOpInPlace(op, [&] {
+    op.setSourceEncoding(cloneWithEncoding(encodingType, encodingAttr.value()));
+  });
+  return success();
+}
+
+/// Updates the result_encoding for `op`. The op has to define a
 /// `result_encoding` parameter.
 template <typename OpTy>
 static LogicalResult
@@ -393,6 +472,16 @@ updateResultEncoding(RewriterBase &rewriter, OpTy op,
   return success();
 }
 
+/// Adds the resolved layouts to all tensor types on stream tensor ops, if
+/// encodings are present. Most of stream tensor ops implement
+/// AffinityOpInterface, where a stream affinity indicates the kind of
+/// environment the ops are expected to run in. When an encoding is present in
+/// the tensor type, the method resolves the layouts, strips outdated
+/// information, and adds the resolved layouts to the encodings. The updated
+/// encodings should have enough information for other lowering transformations.
+/// TODO(hanchung): Add support for stream.tensor.load ops and
+/// stream.tensor.store ops. They are not affinity ops, so additional analysis
+/// will be needed in the work.
 static LogicalResult addLayoutsToTensorPhaseOps(
     ModuleOp moduleOp, IREE::Stream::AffinityAnalysis &affinityAnalysis,
     FunctionOpInterface funcOp,
@@ -424,7 +513,6 @@ static LogicalResult addLayoutsToTensorPhaseOps(
       return affinityOp.emitError("failed on making layout resolvers");
     }
 
-    // TODO(hanchung): Update other Stream operations.
     LogicalResult result =
         TypeSwitch<Operation *, LogicalResult>(affinityOp)
             .Case<IREE::Stream::TensorSizeOfOp>([&](auto op) {
@@ -442,6 +530,15 @@ static LogicalResult addLayoutsToTensorPhaseOps(
             .Case<IREE::Stream::TensorConstantOp>([&](auto op) {
              return updateTensorConstantOp(rewriter, op, layoutResolvers);
             })
+            .Case<IREE::Stream::TensorFillOp>([&](auto op) {
+              return updateTensorFillOp(rewriter, op, layoutResolvers);
+            })
+            .Case<IREE::Stream::TensorCloneOp>(
+                [&](auto op) { return updateTensorCloneOp(rewriter, op); })
+            .Case<IREE::Stream::TensorSliceOp>(
+                [&](auto op) { return updateTensorSliceOp(rewriter, op); })
+            .Case<IREE::Stream::TensorUpdateOp>(
+                [&](auto op) { return updateTensorUpdateOp(rewriter, op); })
             .Default([](Operation *op) {
               return op->emitOpError("Unhandled stream op");
             });
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/specialize_encodings.mlir b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/specialize_encodings.mlir
index 2c71a86e1639..f57664bcce95 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/specialize_encodings.mlir
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/specialize_encodings.mlir
@@ -65,12 +65,39 @@ module {
 
 // -----
 
+#map0 = affine_map<(m, n, k) -> (m, k)>
+#map1 = affine_map<(m, n, k) -> (k, n)>
+#map2 = affine_map<(m, n, k) -> (m, n)>
+#executable_target_vmvx_bytecode_fb = #hal.executable.target<"vmvx", "vmvx-bytecode-fb", {encoding = #iree_encoding.unspecialized_encoding<123>}>
+#device_target_local_0_ = #hal.device.target<"local", {ordinal = 0 : index}, [#executable_target_vmvx_bytecode_fb]> : !hal.device
+#encoding = #iree_encoding.encoding<operand_index = 0 : index, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map0, #map1, #map2]>
+module {
+  util.global private @device_a = #device_target_local_0_
+
+  util.func public @tensor_fill_op(%arg0: f32, %arg1: !stream.resource<*>, %arg2: index, %arg3: index) {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %0 = stream.tensor.fill on(#hal.device.affinity<@device_a>)
+        %arg0, %arg1[%c0, %c0 for %c1, %c1] : f32
+        -> tensor<?x4xf32, #encoding>{%arg2} in %arg1 as !stream.resource<*>{%arg3}
+    util.return
+  }
+}
+// CHECK-DAG:   #[[$ENCODING:.+]] = #iree_encoding.encoding<{{.+}} layouts = [#iree_encoding.specialized_encoding<123, tensor<?x4xf32>>]
+// CHECK:       #[[TARGET:.+]] = #hal.device.target
+// CHECK:       util.global private @[[$DEVICE:.+]] = #[[TARGET]]
+// CHECK-LABEL: util.func public @tensor_fill_op
+// CHECK:         stream.tensor.fill on(#hal.device.affinity<@[[$DEVICE]]>)
+// CHECK-SAME:      f32 -> tensor<?x4xf32, #[[$ENCODING]]>
+
+// -----
+
 // Checks that the stream.tensor.constant op with encoding is not supported.
 #map0 = affine_map<(m, n, k) -> (m, k)>
 #map1 = affine_map<(m, n, k) -> (k, n)>
 #map2 = affine_map<(m, n, k) -> (m, n)>
-#executable_target_vmvx_bytecode_fb = #hal.executable.target<"vmvx", "vmvx-bytecode-fb", {encoding = #iree_cpu.vmvx_encoding_layout<>}>
+#executable_target_vmvx_bytecode_fb = #hal.executable.target<"vmvx", "vmvx-bytecode-fb", {encoding = #iree_encoding.unspecialized_encoding<123>}>
 #device_target_local_0_ = #hal.device.target<"local", {ordinal = 0 : index}, [#executable_target_vmvx_bytecode_fb]> : !hal.device
 #encoding = #iree_encoding.encoding<operand_index = 0 : index, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map0, #map1, #map2]>
 module {
@@ -85,6 +112,76 @@ module {
 
 // -----
 
+// Checks that the stream.tensor.clone op with encoding is not supported.
+
+#map0 = affine_map<(m, n, k) -> (m, k)>
+#map1 = affine_map<(m, n, k) -> (k, n)>
+#map2 = affine_map<(m, n, k) -> (m, n)>
+#executable_target_vmvx_bytecode_fb = #hal.executable.target<"vmvx", "vmvx-bytecode-fb", {encoding = #iree_encoding.unspecialized_encoding<123>}>
+#device_target_local_0_ = #hal.device.target<"local", {ordinal = 0 : index}, [#executable_target_vmvx_bytecode_fb]> : !hal.device
+#encoding = #iree_encoding.encoding<operand_index = 0 : index, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map0, #map1, #map2]>
+module {
+  util.global private @device_a = #device_target_local_0_
+
+  // expected-error @+1 {{failed on adding layouts to Stream::TensorPhaseOp with encodings}}
+  util.func public @tensor_clone_op(%arg0: !stream.resource<*>, %arg1: index, %arg2: index, %arg3: index, %arg4: index) {
+    %0 = stream.tensor.clone on(#hal.device.affinity<@device_a>)
+        %arg0 : tensor<?x4xf32, #encoding>{%arg1} in !stream.resource<*>{%arg2}
+        -> tensor<?x4xf32, #encoding>{%arg1} in !stream.resource<*>{%arg2}
+    util.return
+  }
+}
+
+// -----
+
+// Checks that the stream.tensor.slice op with encoding is not supported.
+
+#map0 = affine_map<(m, n, k) -> (m, k)>
+#map1 = affine_map<(m, n, k) -> (k, n)>
+#map2 = affine_map<(m, n, k) -> (m, n)>
+#executable_target_vmvx_bytecode_fb = #hal.executable.target<"vmvx", "vmvx-bytecode-fb", {encoding = #iree_encoding.unspecialized_encoding<123>}>
+#device_target_local_0_ = #hal.device.target<"local", {ordinal = 0 : index}, [#executable_target_vmvx_bytecode_fb]> : !hal.device
+#encoding = #iree_encoding.encoding<operand_index = 0 : index, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map0, #map1, #map2]>
+module {
+  util.global private @device_a = #device_target_local_0_
+
+  // expected-error @+1 {{failed on adding layouts to Stream::TensorPhaseOp with encodings}}
+  util.func public @tensor_slice_op_with_encoding(%arg0: !stream.resource<*>, %arg1: index, %arg2: index, %arg3: index, %arg4: index) {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %1 = stream.tensor.slice on(#hal.device.affinity<@device_a>)
+        %arg0[%c0, %c1 for %arg3, %c1] : tensor<?x4xf32, #encoding>{%arg1} in !stream.resource<*>{%arg2}
+        -> tensor<?x1xf32, #encoding>{%arg3} in !stream.resource<*>{%arg4}
+    util.return
+  }
+}
+
+// -----
+
+// Checks that the stream.tensor.update op with encoding is not supported.
+
+#map0 = affine_map<(m, n, k) -> (m, k)>
+#map1 = affine_map<(m, n, k) -> (k, n)>
+#map2 = affine_map<(m, n, k) -> (m, n)>
+#executable_target_vmvx_bytecode_fb = #hal.executable.target<"vmvx", "vmvx-bytecode-fb", {encoding = #iree_encoding.unspecialized_encoding<123>}>
+#device_target_local_0_ = #hal.device.target<"local", {ordinal = 0 : index}, [#executable_target_vmvx_bytecode_fb]> : !hal.device
+#encoding = #iree_encoding.encoding<operand_index = 0 : index, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map0, #map1, #map2]>
+module {
+  util.global private @device_a = #device_target_local_0_
+
+  // expected-error @+1 {{failed on adding layouts to Stream::TensorPhaseOp with encodings}}
+  util.func public @tensor_update_op(%arg0: !stream.resource<*>, %arg1: index, %arg2: !stream.resource<*>, %arg3: index, %arg4: index) {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %0 = stream.tensor.update on(#hal.device.affinity<@device_a>)
+        %arg0, %arg2[%c0, %c0] : tensor<2x2xf32, #encoding> in !stream.resource<*>{%arg1}
+        -> tensor<?x4xf32, #encoding>{%arg3} in %arg2 as !stream.resource<*>{%arg4}
+    util.return
+  }
+}
+
+// -----
+
 #executable_target_vmvx_bytecode_fb = #hal.executable.target<"vmvx", "vmvx-bytecode-fb", {encoding = #iree_encoding.unspecialized_encoding<123>}>
 #map = affine_map<(d0) -> (d0)>
 #map0 = affine_map<(m, n, k) -> (m, k)>