[Stream] Fix a getBindingLayoutAttrs bug in SpecializeEncoding pass. #19953
base: main
Conversation
Previously, it did not consider the case where results have tied operands. In that case, we have to skip the result encodings, because the binding is shared between the tied operand and the result: the function does not have a separate argument (i.e., a duplicated binding) for the result. Signed-off-by: hanhanW <[email protected]>
I'm not sure this is correct and may indicate that we need a different approach - bindings are just untyped pointers and do not care about encodings, and it's valid to load and store from the same binding with different encodings (or even different tensor types at different ranges). Basically, if this pass is assuming that a binding is used with a single type or encoding it's not going to work (and I think that's the issue you're hitting?).
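To illustrate the point, here is a hypothetical sketch (not from this PR; the function name, shapes, and offsets are made up): a single untyped !stream.binding can legally be viewed as two unrelated tensor types at different byte ranges.
// Hypothetical example: one untyped binding viewed as two different
// tensor types at different offsets; both subspans are valid today.
func.func @two_views_of_one_binding(%arg0: !stream.binding) {
  %c0 = arith.constant 0 : index
  %c4096 = arith.constant 4096 : index
  // First view: a f16 matrix starting at offset 0.
  %0 = stream.binding.subspan %arg0[%c0] : !stream.binding -> !flow.dispatch.tensor<readonly:tensor<16x64xf16>>
  // Second view: a flat f32 buffer starting at offset 4096.
  %1 = stream.binding.subspan %arg0[%c4096] : !stream.binding -> !flow.dispatch.tensor<writeonly:tensor<256xf32>>
  return
}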
#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
#device_target_local_0_ = #hal.device.target<"local", {ordinal = 0 : index}, [#executable_target_vmvx_bytecode_fb]> : !hal.device
#encoding = #iree_encoding.encoding<operand_index = 0 : index, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map, #map1, #map2]>
please use your test encodings
The testing encoding attribute is used; see "#iree_encoding.unspecialized_encoding<123>" above. These are the basic encodings for gemm.
there's no gemms here :) these tests should not contain any specific encodings
ah, then I think we need to either mark all the parameters optional, or move the cloneWithLayouts to an interface method and create a testing attribute for this.
I'll do it in a follow-up. Do you have any preference? I think adding the testing attribute and the new interface method (e.g., cloneWithEncoding) makes sense. It helps us organize the requirements for encoding specialization, i.e., only the encoding attributes that implement the interface can be specialized.
Yes, I think that is what I hit. Basically, we have a scatter dispatch that shares a binding between an operand and the result, so the number of function arguments does not equal the number of operands plus results:
builtin.module {
  func.func @dispatch_scatter_Dx16x8x128xf16(%arg0: index, %arg1: !stream.binding, %arg2: !stream.binding, %arg3: !stream.binding, %arg4: index, %arg5: index, %arg6: index) {
    %c0 = arith.constant 0 : index
    %0 = flow.dispatch.workload.ordinal %arg4, 1 : index
    %1 = flow.dispatch.workload.ordinal %arg5, 2 : index
    %2 = flow.dispatch.workload.ordinal %arg6, 3 : index
    %3 = stream.binding.subspan %arg1[%c0] : !stream.binding -> !flow.dispatch.tensor<readonly:tensor<?xi64>>{%0}
    %4 = stream.binding.subspan %arg2[%c0] : !stream.binding -> !flow.dispatch.tensor<readonly:tensor<?x16x8x128xf16>>{%1}
    %5 = stream.binding.subspan %arg3[%c0] : !stream.binding -> !flow.dispatch.tensor<readwrite:tensor<?x16x8x128xf16>>{%2}
    %6 = flow.dispatch.workload.ordinal %arg0, 0 : index
    %7 = flow.dispatch.tensor.load %3, offsets = [0], sizes = [%0], strides = [1] : !flow.dispatch.tensor<readonly:tensor<?xi64>>{%0} -> tensor<?xi64>
    %8 = flow.dispatch.tensor.load %4, offsets = [0, 0, 0, 0], sizes = [%1, 16, 8, 128], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<?x16x8x128xf16>>{%1} -> tensor<?x16x8x128xf16>
    %9 = flow.dispatch.tensor.load %5, offsets = [0, 0, 0, 0], sizes = [%2, 16, 8, 128], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readwrite:tensor<?x16x8x128xf16>>{%2} -> tensor<?x16x8x128xf16>
    %10 = tensor.empty(%6) : tensor<?xi32>
    %11 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%7 : tensor<?xi64>) outs(%10 : tensor<?xi32>) {
    ^bb0(%in: i64, %out: i32):
      %13 = arith.trunci %in : i64 to i32
      linalg.yield %13 : i32
    } -> tensor<?xi32>
    %12 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) ins(%8, %11 : tensor<?x16x8x128xf16>, tensor<?xi32>) outs(%9 : tensor<?x16x8x128xf16>) {
    ^bb0(%arg7: f16, %arg8: f16):
      iree_linalg_ext.yield %arg7 : f16
    } -> tensor<?x16x8x128xf16>
    flow.dispatch.tensor.store %12, %5, offsets = [0, 0, 0, 0], sizes = [%2, 16, 8, 128], strides = [1, 1, 1, 1] : tensor<?x16x8x128xf16> -> !flow.dispatch.tensor<readwrite:tensor<?x16x8x128xf16>>{%2}
    return
  }
}
I think we do not have such a case in practice at this moment. I agree that the same binding can be used with different encodings; there is a gap in supporting that case, because we are not able to track which encoding is used in which binding.
we could add a verifier to stream.tensor.dispatch to check that any tied result must have a matching encoding - then we catch it early when the op is lowered from flow
Okay, let me update the verifier and the test case.
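For reference, a minimal C++ sketch of the suggested verifier idea, assuming hypothetical helper and parameter names (verifyTiedEncodings, the tiedOperands array); the real Stream op API may shape this differently, so treat it as an illustration rather than the PR's implementation:
// Sketch only: reject any tied result whose tensor encoding differs from
// its tied operand's encoding, since both views share a single binding.
// Assumes `using namespace mlir;` and the usual MLIR/LLVM support headers.
static LogicalResult
verifyTiedEncodings(ArrayRef<Type> operandTypes, ArrayRef<Type> resultTypes,
                    ArrayRef<std::optional<unsigned>> tiedOperands) {
  for (auto [resultIndex, resultType] : llvm::enumerate(resultTypes)) {
    std::optional<unsigned> tiedIndex = tiedOperands[resultIndex];
    if (!tiedIndex) {
      continue; // Untied results get their own binding; nothing to check.
    }
    auto resultTensorType = dyn_cast<RankedTensorType>(resultType);
    auto operandTensorType =
        dyn_cast<RankedTensorType>(operandTypes[*tiedIndex]);
    if (!resultTensorType || !operandTensorType) {
      continue;
    }
    if (resultTensorType.getEncoding() != operandTensorType.getEncoding()) {
      return failure(); // Mismatched encodings on one shared binding.
    }
  }
  return success();
}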
continue;
}
auto operandIndex = tiedOperand.value() - tiedOperandBase;
if (operandEncodings[operandIndex] != resEncoding) {
@benvanik I might be doing something wrong here. Do you mean to check the "type" or the encoding attribute?
Previously, it did not consider the case where results have tied operands. When a result has a tied operand, we have to skip the result encoding, because the binding is shared between the tied operand and the result: the function does not have a separate argument (i.e., a duplicated binding) for the result.
In this revision, we tighten the definition of the stream.tensor.dispatch op. In the past, it allowed the type of a result to differ from its tied operand. Now we require that any tied result have a matching encoding, because they share the same binding and it is hard to track which encoding is used in which binding today. Adding the check to the verifier lets us catch the mismatch early, when the op is lowered from flow.
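To make the new restriction concrete, here is a hypothetical example of IR the tightened verifier would reject; the assembly is approximated from memory and the names (#enc_a, #enc_b, @ex::@scatter, the SSA values) are placeholders, not from this PR.
// Hypothetical: the result is tied to %target, so they share one binding,
// but the result's encoding #enc_a differs from %target's encoding #enc_b.
// The verifier now reports this mismatch instead of silently miscounting
// bindings later in SpecializeEncodings.
%0 = stream.tensor.dispatch @ex::@scatter(%update, %target) :
    (tensor<?xf16, #enc_a>{%d0} in !stream.resource<*>{%sz0},
     tensor<?xf16, #enc_b>{%d1} in !stream.resource<*>{%sz1})
    -> tensor<?xf16, #enc_a>{%d1} in %target{%sz1}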