[Stream] Add layouts to encodings for all stream tensor AffinityOp. (#19726)

The revision adds support for the remaining AffinityOps that have the
TensorPhase trait, i.e., the TensorCloneOp, TensorSliceOp, TensorFillOp, and
TensorUpdateOp ops. Handling encodings on transfer ops is tricky, so only the
encoding on the fill op is updated; if any of the other ops carries a tensor
encoding, the pass returns a failure for now.
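
For illustration, here is a minimal sketch of the fill rewrite, mirroring the
lit test in the diff below (the `#encoding` and device definitions are elided,
the SSA value names are hypothetical, and the exact resolved `layouts`
attribute depends on the executable target):

  // Before: the target encoding carries no layout information.
  %0 = stream.tensor.fill on(#hal.device.affinity<@device_a>)
      %val, %dst[%c0, %c0 for %c1, %c1] : f32
      -> tensor<?x4xf32, #encoding>{%d0} in %dst as !stream.resource<*>{%size}

  // After: the same op, with the target encoding replaced by one that
  // carries the resolved layouts, e.g.
  //   #iree_encoding.encoding<..., layouts = [#iree_encoding.specialized_encoding<123, tensor<?x4xf32>>]>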

Two stream tensor ops, stream.tensor.load and stream.tensor.store, do not
implement the AffinityOpInterface, so they are not supported by this revision.
Tracking the resource affinity for these two ops should be possible, but it
requires additional analysis, so it is left out of scope here.

The revision also adds the missing documentation for the
`addLayoutsToTensorPhaseOps` method.

---------

Signed-off-by: hanhanW <[email protected]>
hanhanW authored Feb 6, 2025
1 parent ac46df5 commit 5f7b471
Showing 2 changed files with 197 additions and 3 deletions.
@@ -358,6 +358,22 @@ updateTensorSizeOfOp(RewriterBase &rewriter,
  return success();
}

/// Updates the target encoding of `op` with resolved layouts.
static LogicalResult
updateTensorFillOp(RewriterBase &rewriter, IREE::Stream::TensorFillOp op,
                   const SetVector<Attribute> &layoutResolvers) {
  auto encodingType = dyn_cast<RankedTensorType>(op.getTargetEncoding());
  std::optional<IREE::Encoding::EncodingAttr> encodingAttr =
      getEncodingWithNewLayouts(encodingType, layoutResolvers);
  if (!encodingAttr) {
    return success();
  }
  rewriter.modifyOpInPlace(op, [&] {
    op.setTargetEncoding(
        cloneWithEncoding(encodingType, encodingAttr.value()));
  });
  return success();
}

/// Returns failure if `op` has an encoding. The EncodingAttr has padding
/// semantics; a constant op with such an encoding cannot be resolved at this
/// moment.
@@ -375,7 +391,70 @@ updateTensorConstantOp(RewriterBase &rewriter,
  return success();
}

/// Returns a failure if the target encoding type or the update encoding type
/// carries an encoding.
static LogicalResult updateTensorUpdateOp(RewriterBase &rewriter,
                                          IREE::Stream::TensorUpdateOp op) {
  auto targetEncodingType = dyn_cast<RankedTensorType>(op.getTargetEncoding());
  if (targetEncodingType && targetEncodingType.getEncoding()) {
    return failure();
  }
  auto updateEncodingType = dyn_cast<RankedTensorType>(op.getUpdateEncoding());
  if (updateEncodingType && updateEncodingType.getEncoding()) {
    return failure();
  }
  return success();
}

/// Returns a failure if the source encoding type or the result encoding type
/// carries an encoding.
static LogicalResult updateTensorCloneOp(RewriterBase &rewriter,
                                         IREE::Stream::TensorCloneOp op) {
  auto sourceEncodingType = dyn_cast<RankedTensorType>(op.getSourceEncoding());
  if (sourceEncodingType && sourceEncodingType.getEncoding()) {
    return failure();
  }
  auto resultEncodingType = dyn_cast<RankedTensorType>(op.getResultEncoding());
  if (resultEncodingType && resultEncodingType.getEncoding()) {
    return failure();
  }
  return success();
}

/// Returns a failure if the source encoding type or the result encoding type
/// carries an encoding.
static LogicalResult updateTensorSliceOp(RewriterBase &rewriter,
                                         IREE::Stream::TensorSliceOp op) {
  auto sourceEncodingType = dyn_cast<RankedTensorType>(op.getSourceEncoding());
  if (sourceEncodingType && sourceEncodingType.getEncoding()) {
    return failure();
  }
  auto resultEncodingType = dyn_cast<RankedTensorType>(op.getResultEncoding());
  if (resultEncodingType && resultEncodingType.getEncoding()) {
    return failure();
  }
  return success();
}

/// Updates the source_encoding for `op`. The op has to define a
/// `source_encoding` parameter.
template <typename OpTy>
static LogicalResult
updateSourceEncoding(RewriterBase &rewriter, OpTy op,
                     const SetVector<Attribute> &layoutResolvers) {
  auto encodingType = dyn_cast<RankedTensorType>(op.getSourceEncoding());
  std::optional<IREE::Encoding::EncodingAttr> encodingAttr =
      getEncodingWithNewLayouts(encodingType, layoutResolvers);
  if (!encodingAttr) {
    return success();
  }
  rewriter.modifyOpInPlace(op, [&] {
    op.setSourceEncoding(
        cloneWithEncoding(encodingType, encodingAttr.value()));
  });
  return success();
}

/// Updates the result_encoding for `op`. The op has to define a
/// `result_encoding` parameter.
template <typename OpTy>
static LogicalResult
@@ -393,6 +472,16 @@ updateResultEncoding(RewriterBase &rewriter, OpTy op,
  return success();
}

/// Adds the resolved layouts to all tensor types on stream tensor ops, if
/// encodings are present. Most stream tensor ops implement
/// AffinityOpInterface, where a stream affinity indicates the kind of
/// environment the ops are expected to run in. When an encoding is present in
/// the tensor type, the method resolves the layouts, strips outdated
/// information, and adds the resolved layouts to the encodings. The updated
/// encodings should have enough information for other lowering
/// transformations.
/// TODO(hanchung): Add support for stream.tensor.load ops and
/// stream.tensor.store ops. They are not affinity ops, so additional analysis
/// will be needed for that work.
static LogicalResult addLayoutsToTensorPhaseOps(
    ModuleOp moduleOp, IREE::Stream::AffinityAnalysis &affinityAnalysis,
    FunctionOpInterface funcOp,
@@ -424,7 +513,6 @@ static LogicalResult addLayoutsToTensorPhaseOps(
      return affinityOp.emitError("failed on making layout resolvers");
    }

    LogicalResult result =
        TypeSwitch<Operation *, LogicalResult>(affinityOp)
            .Case<IREE::Stream::TensorDispatchOp>([&](auto op) {
@@ -442,6 +530,15 @@ static LogicalResult addLayoutsToTensorPhaseOps(
            .Case<IREE::Stream::TensorConstantOp>([&](auto op) {
              return updateTensorConstantOp(rewriter, op, layoutResolvers);
            })
            .Case<IREE::Stream::TensorFillOp>([&](auto op) {
              return updateTensorFillOp(rewriter, op, layoutResolvers);
            })
            .Case<IREE::Stream::TensorCloneOp>(
                [&](auto op) { return updateTensorCloneOp(rewriter, op); })
            .Case<IREE::Stream::TensorSliceOp>(
                [&](auto op) { return updateTensorSliceOp(rewriter, op); })
            .Case<IREE::Stream::TensorUpdateOp>(
                [&](auto op) { return updateTensorUpdateOp(rewriter, op); })
            .Default([](Operation *op) {
              return op->emitOpError("Unhandled stream op");
            });
@@ -65,12 +65,39 @@ module {

// -----

#map0 = affine_map<(m, n, k) -> (m, k)>
#map1 = affine_map<(m, n, k) -> (k, n)>
#map2 = affine_map<(m, n, k) -> (m, n)>
#executable_target_vmvx_bytecode_fb = #hal.executable.target<"vmvx", "vmvx-bytecode-fb", {encoding = #iree_encoding.unspecialized_encoding<123>}>
#device_target_local_0_ = #hal.device.target<"local", {ordinal = 0 : index}, [#executable_target_vmvx_bytecode_fb]> : !hal.device
#encoding = #iree_encoding.encoding<operand_index = 0 : index, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map0, #map1, #map2]>
module {
  util.global private @device_a = #device_target_local_0_

  util.func public @tensor_fill_op(%arg0: f32, %arg1: !stream.resource<*>, %arg2: index, %arg3: index) {
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %0 = stream.tensor.fill on(#hal.device.affinity<@device_a>)
        %arg0, %arg1[%c0, %c0 for %c1, %c1] : f32
        -> tensor<?x4xf32, #encoding>{%arg2} in %arg1 as !stream.resource<*>{%arg3}
    util.return
  }
}
// CHECK-DAG: #[[$ENCODING:.+]] = #iree_encoding.encoding<{{.+}} layouts = [#iree_encoding.specialized_encoding<123, tensor<?x4xf32>>]
// CHECK: #[[TARGET:.+]] = #hal.device.target
// CHECK: util.global private @[[$DEVICE:.+]] = #[[TARGET]]
// CHECK-LABEL: util.func public @tensor_fill_op
// CHECK: stream.tensor.fill on(#hal.device.affinity<@[[$DEVICE]]>)
// CHECK-SAME: f32 -> tensor<?x4xf32, #[[$ENCODING]]>

// -----

// Checks that the stream.tensor.constant op with encoding is not supported.

#map0 = affine_map<(m, n, k) -> (m, k)>
#map1 = affine_map<(m, n, k) -> (k, n)>
#map2 = affine_map<(m, n, k) -> (m, n)>
#executable_target_vmvx_bytecode_fb = #hal.executable.target<"vmvx", "vmvx-bytecode-fb", {encoding = #iree_encoding.unspecialized_encoding<123>}>
#device_target_local_0_ = #hal.device.target<"local", {ordinal = 0 : index}, [#executable_target_vmvx_bytecode_fb]> : !hal.device
#encoding = #iree_encoding.encoding<operand_index = 0 : index, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map0, #map1, #map2]>
module {
@@ -85,6 +112,76 @@ module {

// -----

// Checks that the stream.tensor.clone op with encoding is not supported.

#map0 = affine_map<(m, n, k) -> (m, k)>
#map1 = affine_map<(m, n, k) -> (k, n)>
#map2 = affine_map<(m, n, k) -> (m, n)>
#executable_target_vmvx_bytecode_fb = #hal.executable.target<"vmvx", "vmvx-bytecode-fb", {encoding = #iree_encoding.unspecialized_encoding<123>}>
#device_target_local_0_ = #hal.device.target<"local", {ordinal = 0 : index}, [#executable_target_vmvx_bytecode_fb]> : !hal.device
#encoding = #iree_encoding.encoding<operand_index = 0 : index, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map0, #map1, #map2]>
module {
  util.global private @device_a = #device_target_local_0_

  // expected-error @+1 {{failed on adding layouts to Stream::TensorPhaseOp with encodings}}
  util.func public @tensor_clone_op(%arg0: !stream.resource<*>, %arg1: index, %arg2: index, %arg3: index, %arg4: index) {
    %0 = stream.tensor.clone on(#hal.device.affinity<@device_a>)
        %arg0 : tensor<?x4xf32, #encoding>{%arg1} in !stream.resource<*>{%arg2}
        -> tensor<?x4xf32, #encoding>{%arg1} in !stream.resource<*>{%arg2}
    util.return
  }
}

// -----

// Checks that the stream.tensor.slice op with encoding is not supported.

#map0 = affine_map<(m, n, k) -> (m, k)>
#map1 = affine_map<(m, n, k) -> (k, n)>
#map2 = affine_map<(m, n, k) -> (m, n)>
#executable_target_vmvx_bytecode_fb = #hal.executable.target<"vmvx", "vmvx-bytecode-fb", {encoding = #iree_encoding.unspecialized_encoding<123>}>
#device_target_local_0_ = #hal.device.target<"local", {ordinal = 0 : index}, [#executable_target_vmvx_bytecode_fb]> : !hal.device
#encoding = #iree_encoding.encoding<operand_index = 0 : index, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map0, #map1, #map2]>
module {
  util.global private @device_a = #device_target_local_0_

  // expected-error @+1 {{failed on adding layouts to Stream::TensorPhaseOp with encodings}}
  util.func public @tensor_slice_op_with_encoding(%arg0: !stream.resource<*>, %arg1: index, %arg2: index, %arg3: index, %arg4: index) {
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %1 = stream.tensor.slice on(#hal.device.affinity<@device_a>)
        %arg0[%c0, %c1 for %arg3, %c1] : tensor<?x4xf32, #encoding>{%arg1} in !stream.resource<*>{%arg2}
        -> tensor<?x1xf32, #encoding>{%arg3} in !stream.resource<*>{%arg4}
    util.return
  }
}

// -----

// Checks that the stream.tensor.update op with encoding is not supported.

#map0 = affine_map<(m, n, k) -> (m, k)>
#map1 = affine_map<(m, n, k) -> (k, n)>
#map2 = affine_map<(m, n, k) -> (m, n)>
#executable_target_vmvx_bytecode_fb = #hal.executable.target<"vmvx", "vmvx-bytecode-fb", {encoding = #iree_encoding.unspecialized_encoding<123>}>
#device_target_local_0_ = #hal.device.target<"local", {ordinal = 0 : index}, [#executable_target_vmvx_bytecode_fb]> : !hal.device
#encoding = #iree_encoding.encoding<operand_index = 0 : index, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map0, #map1, #map2]>
module {
  util.global private @device_a = #device_target_local_0_

  // expected-error @+1 {{failed on adding layouts to Stream::TensorPhaseOp with encodings}}
  util.func public @tensor_update_op(%arg0: !stream.resource<*>, %arg1: index, %arg2: !stream.resource<*>, %arg3: index, %arg4: index) {
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %0 = stream.tensor.update on(#hal.device.affinity<@device_a>)
        %arg0, %arg2[%c0, %c0] : tensor<2x2xf32, #encoding> in !stream.resource<*>{%arg1}
        -> tensor<?x4xf32, #encoding>{%arg3} in %arg2 as !stream.resource<*>{%arg4}
    util.return
  }
}

// -----

#executable_target_vmvx_bytecode_fb = #hal.executable.target<"vmvx", "vmvx-bytecode-fb", {encoding = #iree_encoding.unspecialized_encoding<123>}>
#map = affine_map<(d0) -> (d0)>
#map0 = affine_map<(m, n, k) -> (m, k)>
