Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Stream] Add support for executable duplication for tied operand cases. #19953

Merged
merged 11 commits into from
Feb 14, 2025
34 changes: 34 additions & 0 deletions compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,36 @@ verifyDispatchWorkload(Operation *op, IREE::Stream::ExecutableExportOp exportOp,
return success();
}

// Verifies that the encoding of each tied result matches the encoding of the
// operand it is tied to. Untied results are not constrained.
//
// |operandEncodingsAttr| / |resultEncodingsAttr| are dense arrays with one
// encoding attribute per operand / result, in op order.
static LogicalResult verifyTiedOperandEncodings(Operation *op,
                                                ArrayAttr operandEncodingsAttr,
                                                ArrayAttr resultEncodingsAttr) {
  auto tiedOp = dyn_cast<IREE::Util::TiedOpInterface>(op);
  if (!tiedOp) {
    return op->emitOpError()
           << "the op does not implement IREE::Util::TiedOpInterface";
  }

  ArrayRef<Attribute> operandEncodings = operandEncodingsAttr.getValue();
  // Tied operand indices reported by the interface are offset by the op's
  // tied operand base; subtract it to index into |operandEncodings|.
  unsigned tiedOperandBase = tiedOp.getTiedOperandsIndexAndLength().first;
  for (auto [idx, resEncoding] :
       llvm::enumerate(resultEncodingsAttr.getValue())) {
    auto tiedOperand = tiedOp.getTiedResultOperandIndex(idx);
    if (!tiedOperand.has_value()) {
      // Untied result; nothing to check against.
      continue;
    }
    auto operandIndex = tiedOperand.value() - tiedOperandBase;
    if (operandEncodings[operandIndex] != resEncoding) {
      return op->emitError()
             << "the " << operandIndex << "-th operandEncoding ("
             << operandEncodings[operandIndex]
             << ") does not match the resultEncoding (" << resEncoding << ")";
    }
  }

  return success();
}

// Verifies that |dynamicDims| contains the appropriate number of dims for all
// the dynamic dimensions in |type|.
static LogicalResult verifyOpDynamicDims(Operation *op, TypeRange types,
Expand Down Expand Up @@ -2112,6 +2142,10 @@ LogicalResult TensorDispatchOp::verify() {
op.getResultEncodingDims()))) {
return failure();
}
if (failed(verifyTiedOperandEncodings(op, op.getOperandEncodings(),
op.getResultEncodings()))) {
return failure();
}
return success();
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: iree-opt --split-input-file %s | iree-opt --split-input-file | FileCheck %s
// RUN: iree-opt --split-input-file %s --verify-diagnostics | FileCheck %s

// CHECK-LABEL: @tensorImport
util.func private @tensorImport(%arg0: !hal.buffer_view, %arg1: index) -> !stream.resource<external> {
Expand Down Expand Up @@ -162,7 +162,20 @@ util.func private @tensorDispatch(%arg0: !stream.resource<*>, %arg1: index, %arg
%c3 = arith.constant 3 : index
%c4 = arith.constant 4 : index
// CHECK: = stream.tensor.dispatch @executable::@dispatch[%c1, %c2, %c3](%arg0, %c4) :
// CHECK-SAME: (tensor<4x?xf32>{%arg2} in !stream.resource<*>{%arg1}, index) -> tensor<?x4xf32>{%arg2} in %arg0{%arg1}
// CHECK-SAME: (tensor<4x?xf32>{%arg2} in !stream.resource<*>{%arg1}, index) -> tensor<4x?xf32>{%arg2} in %arg0{%arg1}
%0 = stream.tensor.dispatch @executable::@dispatch[%c1, %c2, %c3](%arg0, %c4) : (tensor<4x?xf32>{%arg2} in !stream.resource<*>{%arg1}, index) -> tensor<4x?xf32>{%arg2} in %arg0{%arg1}
util.return %0 : !stream.resource<*>
}

// -----

// Tests that the stream.tensor.dispatch verifier rejects a tied result whose
// encoding (tensor<?x4xf32>) differs from the encoding of the operand it is
// tied to (tensor<4x?xf32>).
util.func private @tensorDispatchMismatch(%arg0: !stream.resource<*>, %arg1: index, %arg2: index) -> !stream.resource<*> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%c3 = arith.constant 3 : index
%c4 = arith.constant 4 : index
// expected-error @+1 {{the 0-th operandEncoding (tensor<4x?xf32>) does not match the resultEncoding (tensor<?x4xf32>)}}
%0 = stream.tensor.dispatch @executable::@dispatch[%c1, %c2, %c3](%arg0, %c4) : (tensor<4x?xf32>{%arg2} in !stream.resource<*>{%arg1}, index) -> tensor<?x4xf32>{%arg2} in %arg0{%arg1}
util.return %0 : !stream.resource<*>
}
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,53 @@ updateBindingEncodings(FunctionOpInterface funcOp,
return success();
}

/// Returns the binding layout attributes for |dispatchOp| in
/// |operands| + |results| order: every operand encoding followed by the
/// encodings of the *untied* results only. A result that is tied to an
/// operand shares that operand's binding, so its encoding is omitted.
///
/// Example 1:
///
///   %0 = stream.tensor.dispatch ...(%arg0, %c4)
///       : (tensor<4x?xf32, #encoding> in !resource, index)
///       -> tensor<4x?xf32, #encoding> in !resource
///
/// The above dispatch op does not have tied operands. Thus, it returns
/// |#resolved_encoding, whatever_without_encoding, #resolved_encoding|
///
/// Example 2:
///
///   %0 = stream.tensor.dispatch ...(%arg0, %c4) : tensor<4x?xf32, #encoding>
///       -> tensor<4x?xf32, #encoding> in %arg0
///
/// The above dispatch op ties the result to the first operand. Thus, the
/// result encoding is stripped. It returns
/// |#resolved_encoding, whatever_without_encoding|
static SmallVector<Attribute>
getBindingLayoutAttrs(IREE::Stream::TensorDispatchOp dispatchOp) {
  // Assume every result is untied, then overwrite from the op's tied
  // operands attribute when it is present.
  SmallVector<int64_t> tiedResultOperands(
      dispatchOp.getNumResults(), IREE::Util::TiedOpInterface::kUntiedIndex);
  if (std::optional<ArrayAttr> tiedAttr = dispatchOp.getTiedOperands()) {
    tiedResultOperands =
        llvm::map_to_vector(tiedAttr.value(), [](Attribute attr) {
          return llvm::cast<IntegerAttr>(attr).getInt();
        });
  }

  // Seed with all operand encodings; append encodings for untied results.
  SmallVector<Attribute> layoutAttrs(
      dispatchOp.getOperandEncodings().getValue());
  for (auto [encoding, tiedIndex] :
       llvm::zip_equal(dispatchOp.getResultEncodings().getValue(),
                       tiedResultOperands)) {
    if (tiedIndex == IREE::Util::TiedOpInterface::kUntiedIndex) {
      layoutAttrs.push_back(encoding);
    }
  }

  return layoutAttrs;
}

/// Duplicates stream.executables based on the operand encodings and result
/// encodings of stream.tensor.dispatch ops. Some executables can be launched by
/// different devices. It can produce wrong codegen artifacts when bindings
Expand Down Expand Up @@ -175,10 +222,8 @@ duplicateExecutablesPerLayoutVariant(ModuleOp moduleOp, SymbolTable symbolTable,
llvm::MapVector<IREE::Stream::TensorDispatchOp, SmallVector<Attribute>>
dispatchOpBindingLayouts;
for (auto dispatchOp : candidates) {
SmallVector<Attribute> bindingLayoutAttrs(
dispatchOp.getOperandEncodings().getValue());
llvm::append_range(bindingLayoutAttrs,
dispatchOp.getResultEncodings().getValue());
SmallVector<Attribute> bindingLayoutAttrs =
getBindingLayoutAttrs(dispatchOp);
dispatchOpBindingLayouts[dispatchOp] = bindingLayoutAttrs;
dispatchOp.forEachEntryPointAttr([&](SymbolRefAttr entryPoint) {
auto exportOp = cast<IREE::Stream::ExecutableExportOp>(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ util.func public @denseTensorDispatch(
// CHECK-SAME: %[[RESOURCE1]][%[[ZERO]] to %[[RESOURCE1_SIZE]] for %[[RESOURCE1_SIZE]]])
// CHECK-SAME: (!stream.resource<transient>{%[[RESOURCE0_SIZE]]}, !stream.resource<external>{%[[RESOURCE1_SIZE]]}) ->
// CHECK-SAME: (!stream.resource<external>{%[[RESOURCE1_SIZE]]}, %[[RESOURCE1]]{%[[RESOURCE1_SIZE]]})
%results:2 = stream.tensor.dispatch @ex::@entry(%resource0, %resource1) : (tensor<4x?xf32>{%tensor0_dim} in !stream.resource<transient>{%resource0_size}, tensor<?xi32>{%tensor1_dim} in !stream.resource<external>{%resource1_size}) -> (tensor<?xi32>{%tensor1_dim} in !stream.resource<external>{%resource1_size}, tensor<?xf32>{%tensor1_dim} in %resource1{%resource1_size})
%results:2 = stream.tensor.dispatch @ex::@entry(%resource0, %resource1) : (tensor<4x?xf32>{%tensor0_dim} in !stream.resource<transient>{%resource0_size}, tensor<?xi32>{%tensor1_dim} in !stream.resource<external>{%resource1_size}) -> (tensor<4x?xf32>{%tensor0_dim} in !stream.resource<external>{%resource1_size}, tensor<?xi32>{%tensor1_dim} in %resource1{%resource1_size})
// CHECK: util.return %[[RESULTS]]#0, %[[RESULTS]]#1
util.return %results#0, %results#1 : !stream.resource<external>, !stream.resource<external>
}
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,47 @@ util.func public @tensor_update_op(%arg0: !stream.resource<*>, %arg1: index, %ar

// -----

// Tests encoding specialization for a stream.tensor.dispatch whose result is
// tied to an operand: both the operand and the tied result resolve to the
// same specialized encoding layout.
#executable_target_vmvx_bytecode_fb = #hal.executable.target<"vmvx", "vmvx-bytecode-fb", {encoding = #iree_encoding.unspecialized_encoding<123>}>
#map = affine_map<(d0, d1, d2) -> (d0, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
#device_target_local_0_ = #hal.device.target<"local", {ordinal = 0 : index}, [#executable_target_vmvx_bytecode_fb]> : !hal.device
#encoding = #iree_encoding.encoding<operand_index = 0 : index, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map, #map1, #map2]>

util.global private @device_a = #device_target_local_0_
stream.executable private @executable {
stream.executable.export public @dispatch
builtin.module {
func.func @dispatch(%arg0: !stream.binding, %arg1: index) {
%c0 = arith.constant 0 : index
%0 = stream.binding.subspan %arg0[%c0] : !stream.binding -> !flow.dispatch.tensor<readwrite:tensor<16xf32, #encoding>>
return
}
}
}
util.func public @tensor_dispatch_with_tied_operands(%arg0: !stream.resource<external>, %arg1: index, %arg2: index) -> !stream.resource<*> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%c3 = arith.constant 3 : index
%c4 = arith.constant 4 : index
%0 = stream.async.transfer %arg0 : !stream.resource<external>{%arg2} from(#hal.device.affinity<@device_a>) -> to(#hal.device.affinity<@device_a>) !stream.resource<*>{%arg2}
%1 = stream.tensor.dispatch on(#hal.device.affinity<@device_a>) @executable::@dispatch[%c1, %c2, %c3](%0, %c4) : (tensor<4x?xf32, #encoding>{%arg2} in !stream.resource<*>{%arg1}, index) -> tensor<4x?xf32, #encoding>{%arg2} in %0{%arg1}
util.return %1 : !stream.resource<*>
}
// CHECK-DAG: #[[$ENCODING:.+]] = #iree_encoding.encoding<{{.+}} layouts = [#iree_encoding.specialized_encoding<123, tensor<4x?xf32>>]
// CHECK: #[[TARGET:.+]] = #hal.device.target
// CHECK: util.global private @[[$DEVICE:.+]] = #[[TARGET]]
// CHECK-LABEL: util.func public @tensor_dispatch_with_tied_operands
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]
// CHECK: stream.tensor.dispatch on(#hal.device.affinity<@[[$DEVICE]]>)
// CHECK-SAME: tensor<4x?xf32, #[[$ENCODING]]>{%[[ARG2]]}
// CHECK-SAME: tensor<4x?xf32, #[[$ENCODING]]>{%[[ARG2]]}
// -----

#executable_target_vmvx_bytecode_fb = #hal.executable.target<"vmvx", "vmvx-bytecode-fb", {encoding = #iree_encoding.unspecialized_encoding<123>}>
#map = affine_map<(d0) -> (d0)>
#map0 = affine_map<(m, n, k) -> (m, k)>
Expand Down
Loading