[LoweringStrategy] Use a more general method to fetch input dims and sizes #1090
Conversation
Nice! A few comments to address.
```cpp
SmallVector<int64_t> shapes = linalgOp.getStaticLoopRanges();
if (mDims.size() + nDims.size() + kDims.size() > shapes.size()) {
  return linalgOp.emitOpError(
      "the total of m/n/k dims is larger than the number of loops.");
}
```
Do we really need this check? This seems like something the upstream utility should handle, if it doesn't already. If we do need the check, I think we can simply `assert` on it.
```cpp
auto getSizesAt = [&shapes](const SmallVector<unsigned, 2> &idx) {
  SmallVector<int64_t, 2> sizes;
  for (auto i : idx) sizes.push_back(shapes[i]);
```
Suggested change:

```diff
-  for (auto i : idx) sizes.push_back(shapes[i]);
+  for (unsigned i : idx) sizes.push_back(shapes[i]);
```
```cpp
AMDAIEDevice targetDevice, uint32_t numRows, uint32_t numCols,
uint32_t numLoops) {
```
Can we use `linalgOp.getNumLoops()` within each function? In that case the `numLoops` function argument can be dropped here and elsewhere.
There is a bug in the existing code for getting the M/N/K sizes from a matmul-like op: K shouldn't be `lhsShape[1]` if the op is a matmul-transpose-a. It is hard to infer the sizes if the input matmul-like op is transposed and in `linalg.generic` form, or if the input has a higher number of dimensions, as with mmt4d ops.

In addition, the indexing maps of matmul-like `linalg.generic` ops can be transposed during dispatch generation by default, because of `TransposeGenericOpsPass`. For example, the iterator types

```
parallel, parallel, reduction, parallel, parallel, reduction
```

will become

```
parallel, parallel, parallel, parallel, reduction, reduction
```

after dispatch generation. So if we still look up the pack sizes and tile sizes at the dims of the former indexing map, we will generate wrong sizes for tiling.
This PR uses the upstream method `linalg::inferContractionDims` to infer the dim indices of M/N/K for all contraction ops.