-
Notifications
You must be signed in to change notification settings - Fork 12
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Annotate kernel params #289
base: main
Are you sure you want to change the base?
Conversation
1057012
to
20604de
Compare
src/enzyme_ad/jax/Passes/Passes.td
Outdated
2. Block dimension | ||
3. Block index. | ||
Additionally, set all the following attributes for all the kernel pointers: | ||
1. align 16 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
technically I looked it up and we can assume 128 byte aligned: https://github.com/openxla/xla/blob/109a8ff382f2d174ee13504fb4f140a745cb25a2/xla/service/gpu/gpu_constants.h#L41
3caf315
to
0dba906
Compare
auto callee = symTable.lookup<FunctionOpInterface>(symbolName); | ||
if (!callee) | ||
return WalkResult::advance(); | ||
MLIRContext *ctx = callee->getContext(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Something I realize here. this (and the above) is wrong. A callee could have multiple callers. You need to conservatively min the dereferencable size over all callers (and same for the range above, though that's a max)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Mind taking another look? thanks!
e0733ea
to
fb7f715
Compare
if (matchPattern(maybeCst, m_ConstantInt(&intValue))) | ||
target->setAttr("range", LLVM::ConstantRangeAttr::get( | ||
ctx, 32, 0, intValue.getSExtValue())); | ||
if (matchPattern(maybeCst, m_ConstantInt(&intValue))) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think this is the best way to do things.
If, for example, there is already range metadata (that is weaker), we won't do the update.
I think the thing to do here is to
- get all kernelcall ops
- then get all called functions
- for each called function
a) loop over all callers, getting the max over all ranges/etc
b) set the derived values/aliasing/etc
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I did not understand your previous comment then. What do you mean by "weaker range"? Can you provide an example?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah supposed that the existing kernel code before this pass has a range of 1-100000, but all the callers have a range 1-100. We would be taking the max over the previous setting and all the actual callers
No description provided.