Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[LoweringStrategy] Use a more general method to fetch input dims and sizes #1090

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

yzhang93
Copy link
Contributor

@yzhang93 yzhang93 commented Feb 8, 2025

There is a bug in the existing codes to get M/N/K size from a matmul-like op, i.e.,

const uint64_t M = initShape[0];
const uint64_t N = initShape[1];
const uint64_t K = lhsShape[1];

Apparently, K shouldn't be lhsShape[1] if it's a matmul-tranpose-a op.
It's hard to infer the shape if the input matmul-like ops are transposed and in linalg.generic form. Or even the input has a higher number of dimensions such as mmt4d ops.

In addition, the indexing maps of matmul-like linalg.generic ops can be transposed during dispatch generation by default because of TransposeGenericOpsPass.
Example:
parallel parallel reduction parallel parallel reduction will become
parallel parallel parallel parallel reduction reduction after dispatch generation.
So if we still put the pack size/tile size at the dim for the former indexing map, it will generate wrong size for tiling.

This PR uses the upstream method linalg::inferContractionDims to infer the dim indices of M/N/K for all contraction ops.

@yzhang93 yzhang93 force-pushed the refactor_get_shapes_dims branch from 4ff5e29 to bba961e Compare February 8, 2025 04:33
Copy link
Contributor

@Abhishek-Varma Abhishek-Varma left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice! A few comments to address.

Comment on lines +111 to +115
SmallVector<int64_t> shapes = linalgOp.getStaticLoopRanges();
if (mDims.size() + nDims.size() + kDims.size() > shapes.size()) {
return linalgOp.emitOpError(
"the total of m/n/k dims is larger than the number of loops.");
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we really need this check ?
This seems like something which the upstream utility should handle, if not already.

Therefore if at all we need to ensure this check, I think we can simply assert on this.


auto getSizesAt = [&shapes](const SmallVector<unsigned, 2> &idx) {
SmallVector<int64_t, 2> sizes;
for (auto i : idx) sizes.push_back(shapes[i]);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
for (auto i : idx) sizes.push_back(shapes[i]);
for (unsigned i : idx) sizes.push_back(shapes[i]);

Comment on lines +437 to +438
AMDAIEDevice targetDevice, uint32_t numRows, uint32_t numCols,
uint32_t numLoops) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can use linalgOp.getNumLoops(); within each function ?
In that case sending numLoops as a function argument can be dropped here and elsewhere.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants