-
Notifications
You must be signed in to change notification settings - Fork 37
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Support tensor of indices as loop iter-arg in sequences of pointer ar…
…ithmetic (#180) This PR adds support for tensor of indices that are updated in each loop iteration while also being used in pointer arithmetic sequences. ## Approach Similarly to the pointer types, in the PtrAnalysis pre-pass, we prematurely generate the `tts.get_structured_state` ops for tensor of integers. The important note here is we do not need to know whether these ops will eventually be used in a pointer arithmetic sequence. Any values that are not used in a pointer arithmetic sequence will be removed later in the process. This approach can easily be extended to other kinds of values that might be used in pointer arithmetic sequences. At a high level, `tts.get_structured_state` can always be used to "wrap" a triton value. This op returns two kinds of values: the first value is always of the same type as the wrapped value, while the remaining values expose the important fields in `PtrState` that are necessary for codegen in scf.for. The first return value of `tts.get_structured_state` is always an SSA value of the same type as the original value; users of the original triton value will then use this first return value from `tts.get_structured_state` instead. With this approach, even if the original triton value ends up not being used in a pointer arithmetic sequence, it is very easy to revert the IR to the original form by simply deleting the `tts.get_structured_state` op and forwarding the original triton value to its users again. The other return values then expose the important fields in PtrState that are necessary to generate the code in loops (offsets and strides). Within a loop, for every wrapped triton value returned by a `tts.get_structured_state` op at index `i`, we can always get the corresponding offsets in each loop iteration at index `i + 1` and strides at index `i + 2`. ## Changes + Updated the pre-pass to insert `tts.get_structured_state` ops that wrap tensor of indices + With the introduction of tensor of indices in loops, we now have to manually visit the `tts.get_structured_state` ops to generate the ops for updating PtrState. We previously did not have to do this because triton pointers always have a `tt.addptr` at the end of each loop, right before yielding the values, which always triggers the process for generating the state-update ops + Logic for determining whether a loop iter-arg should have its PtrState updated is improved. We do a BFS-like scan starting from the return values of `tts.get_structured_state` ops to determine if an iter-arg originates from a value that may need its PtrState populated + Preliminary support for mask sequences being updated in a loop; this is a bit of a hack and will need more robust implementation if these use cases appear more frequently. + Add tests for various scenarios
- Loading branch information
1 parent
5bd61a0
commit 177a624
Showing
16 changed files
with
880 additions
and
91 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.