Potentially incorrect transformation? #241
module {
  func.func @main(%arg0: tensor<2x12x4xf32>) -> (tensor<2x12x12xf32>, tensor<2x12x4xf32>) {
    %0 = stablehlo.transpose %arg0, dims = [2, 1, 0] : (tensor<2x12x4xf32>) -> tensor<4x12x2xf32>
    %cst = stablehlo.constant dense<2.000000e+00> : tensor<12x12x2xf32>
    %1 = stablehlo.transpose %0, dims = [1, 0, 2] : (tensor<4x12x2xf32>) -> tensor<12x4x2xf32>
    %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<12x12x2xf32>
    %2 = stablehlo.transpose %1, dims = [2, 0, 1] : (tensor<12x4x2xf32>) -> tensor<2x12x4xf32>
    %3 = stablehlo.transpose %0, dims = [2, 0, 1] : (tensor<4x12x2xf32>) -> tensor<2x4x12xf32>
    %4 = stablehlo.convert %2 : tensor<2x12x4xf32>
    %5 = stablehlo.convert %3 : tensor<2x4x12xf32>
    %6 = stablehlo.dot_general %4, %5, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<2x12x4xf32>, tensor<2x4x12xf32>) -> tensor<2x12x12xf32>
    %7 = stablehlo.transpose %6, dims = [1, 2, 0] : (tensor<2x12x12xf32>) -> tensor<12x12x2xf32>
    %8 = stablehlo.multiply %cst, %7 : tensor<12x12x2xf32>
    %9 = stablehlo.transpose %8, dims = [2, 1, 0] : (tensor<12x12x2xf32>) -> tensor<2x12x12xf32>
    %10 = stablehlo.transpose %0, dims = [2, 1, 0] : (tensor<4x12x2xf32>) -> tensor<2x12x4xf32>
    return %9, %10 : tensor<2x12x12xf32>, tensor<2x12x4xf32>
  }
}
This is enough to cause the crash.
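To judge whether the batch pass's output is correct, it helps to know what the original module should compute. Below is a minimal NumPy sketch that mirrors each op one-for-one; it assumes NumPy as a stand-in for a real StableHLO interpreter, and `reference` and the `v*` names are hypothetical. Composing the transposes by hand shows that %2 is just %arg0, %3 is %arg0 with its last two axes swapped, and %10 is %arg0 again, so the first result should be 2 · (A Aᵀ) for each batch slice A.

```python
import numpy as np


def reference(arg0: np.ndarray):
    """Hypothetical NumPy mirror of @main; np.transpose uses the same
    permutation convention as stablehlo.transpose (result dim i = operand dim perm[i])."""
    assert arg0.shape == (2, 12, 4)
    v0 = np.transpose(arg0, (2, 1, 0))                  # %0: 4x12x2
    cst = np.full((12, 12, 2), 2.0, dtype=np.float32)   # %cst
    v1 = np.transpose(v0, (1, 0, 2))                    # %1: 12x4x2
    v2 = np.transpose(v1, (2, 0, 1))                    # %2: 2x12x4
    v3 = np.transpose(v0, (2, 0, 1))                    # %3: 2x4x12
    # %4, %5: f32 -> f32 converts are no-ops here.
    # %6: batch dim 0, contract lhs dim 2 with rhs dim 1, i.e. a batched matmul.
    v6 = np.matmul(v2, v3)                              # 2x12x12
    v7 = np.transpose(v6, (1, 2, 0))                    # %7: 12x12x2
    v8 = cst * v7                                       # %8
    v9 = np.transpose(v8, (2, 1, 0))                    # %9: 2x12x12
    v10 = np.transpose(v0, (2, 1, 0))                   # %10: 2x12x4
    return v9, v10


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 12, 4)).astype(np.float32)
out9, out10 = reference(x)

# Closed form: %9 == 2 * (x @ x^T) per batch (symmetric, so the final
# transpose of the last two axes does not change it); %10 == x.
assert np.allclose(out9, 2.0 * (x @ x.transpose(0, 2, 1)), rtol=1e-5, atol=1e-5)
assert np.array_equal(out10, x)
```

Comparing the batched module's output against this reference (with the batch dimension sliced off) is one way to tell a crash apart from a silently wrong rewrite.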
I just updated per @jumerckx's PR for batching. Does that resolve it? ... Actually, no, I suppose not, since the latter case doesn't have a call. I think the issue here is that dot_general needs a custom batch interface implementation, like transpose has.
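One way to picture what a custom batch interface for dot_general has to do, by analogy with transpose: when the batch pass prepends new leading dimensions to every operand, each op's dimension attributes must be rewritten. The sketch below is hypothetical (`DotDims`, `batch_transpose_perm`, and `batch_dot_dims` are illustrative names, not Enzyme-JAX or StableHLO APIs) and assumes the new dimensions are prepended at position 0. For transpose, the new dims map to themselves and every old index shifts; for dot_general, the new dims must also be registered as batching dimensions on both operands.

```python
from dataclasses import dataclass
from typing import Tuple

Dims = Tuple[int, ...]


@dataclass(frozen=True)
class DotDims:
    """Hypothetical stand-in for stablehlo.dot_general dimension numbers."""
    lhs_batch: Dims
    rhs_batch: Dims
    lhs_contract: Dims
    rhs_contract: Dims


def _shift(dims: Dims, n: int) -> Dims:
    """Shift every existing dimension index past the n new leading dims."""
    return tuple(d + n for d in dims)


def batch_transpose_perm(perm: Dims, n_new: int = 1) -> Dims:
    """New leading dims stay in place; the original permutation is shifted."""
    return tuple(range(n_new)) + _shift(perm, n_new)


def batch_dot_dims(dims: DotDims, n_new: int = 1) -> DotDims:
    """New leading dims become the outermost batching dims on both operands."""
    lead = tuple(range(n_new))
    return DotDims(
        lhs_batch=lead + _shift(dims.lhs_batch, n_new),
        rhs_batch=lead + _shift(dims.rhs_batch, n_new),
        lhs_contract=_shift(dims.lhs_contract, n_new),
        rhs_contract=_shift(dims.rhs_contract, n_new),
    )


# The dot_general in the reproducer: batching_dims = [0] x [0], contracting_dims = [2] x [1]
orig = DotDims((0,), (0,), (2,), (1,))
print(batch_dot_dims(orig))
# DotDims(lhs_batch=(0, 1), rhs_batch=(0, 1), lhs_contract=(3,), rhs_contract=(2,))
print(batch_transpose_perm((2, 1, 0)))
# (0, 3, 2, 1)
```

Prepending the new dims to the batching lists (rather than appending them) matters because dot_general emits batching dimensions first in its result, so this keeps the new batch dimension at result position 0.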
This failure seems to be very specific to dot_general followed by a multiply. If I replace the
I will try to reduce this, but I'm opening an initial version for now.
This occurs when running the batch pass.