Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add FP32 and Bias to fulfill the functionalities required by `torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION` (#22)

This PR includes the following major changes:

1. Add Bias support in the Triton kernel, for both forward and backward directions
2. Add `fp32` datatype support, and the corresponding tuning database information
3. Fix "argument list too long" error during linking
4. Improve `table_tool.py` to partially dump/load `.csv` files, allowing database merging (*)
5. Refactor the unit tests to use PyTorch's method to estimate ATOL/RTOL

Known limitations:

1. The gradient of Bias assumes a real rank-4 tensor (`.expand()`-ed ones are unlikely to work). This requirement is not checked, and failure may be silent. Bias itself is not affected since it is read-only.
2. `test_forward.py` still uses the old method to estimate ATOL/RTOL

(*) Example of using `table_tool.py` to merge databases:

```
DB=v2python/rules/tuning_database.sqlite3
python -m v2python.table_tool -k '' --action dumpcsv \
    -f $DB --table_name 'FLASH$attn_fwd' \
    --table_file 'attn_fwd.fp32mi300.csv' \
    --select_where 'inputs$Q_dtype = "torch.float32"'
git checkout another_branch -- $DB
python -m v2python.table_tool -k '' --action loadcsv \
    -f $DB --table_name 'FLASH$attn_fwd' \
    --table_file attn_fwd.fp32mi300.csv \
    --ignore_id
```

Note: for simplicity, `--ignore_id` does not support CSV files in which `id` is not the first column.
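To make the role of the `--ignore_id` flag in the `loadcsv` example more concrete, here is a minimal sketch of what loading a dumped CSV while discarding its leading `id` column could look like. It is not the actual `table_tool.py` implementation. The helper `load_csv_ignoring_id`, the table name `demo_attn_fwd`, and the column `tuned_kernel$BLOCK_M` are hypothetical names chosen for illustration; only `inputs$Q_dtype` comes from the commit message. The sketch assumes `id` is an integer primary key, so rows dumped from one database would collide with existing ids in another unless the column is dropped and reassigned.

```python
import csv
import io
import sqlite3


def load_csv_ignoring_id(conn, table, csv_text):
    """Insert CSV rows into `table`, dropping the leading 'id' column.

    Hypothetical helper: SQLite assigns fresh ids, so the incoming rows
    cannot clash with ids already present in the target database.
    """
    reader = csv.reader(io.StringIO(csv_text))
    header = next(reader)
    # Mirror the limitation stated in the note: 'id' must be the first column.
    if header[0] != "id":
        raise ValueError("this sketch expects 'id' to be the first CSV column")
    columns = header[1:]
    col_list = ", ".join(f'"{c}"' for c in columns)
    placeholders = ", ".join("?" for _ in columns)
    sql = f'INSERT INTO "{table}" ({col_list}) VALUES ({placeholders})'
    rows = [row[1:] for row in reader]
    conn.executemany(sql, rows)
    conn.commit()
    return len(rows)


# Target database that already holds one fp16 entry with id 1.
conn = sqlite3.connect(":memory:")
conn.execute(
    'CREATE TABLE "demo_attn_fwd" ('
    'id INTEGER PRIMARY KEY, "inputs$Q_dtype" TEXT, "tuned_kernel$BLOCK_M" INTEGER)'
)
conn.execute(
    'INSERT INTO "demo_attn_fwd" ("inputs$Q_dtype", "tuned_kernel$BLOCK_M") VALUES (?, ?)',
    ("torch.float16", 128),
)

# CSV dumped from another database; its ids (1, 2) overlap with the target's.
csv_text = (
    "id,inputs$Q_dtype,tuned_kernel$BLOCK_M\n"
    "1,torch.float32,64\n"
    "2,torch.float32,32\n"
)

inserted = load_csv_ignoring_id(conn, "demo_attn_fwd", csv_text)
merged = conn.execute(
    'SELECT id, "inputs$Q_dtype", "tuned_kernel$BLOCK_M" FROM "demo_attn_fwd" ORDER BY id'
).fetchall()
print(merged)
```

Inserting the CSV rows with their original ids would fail on the primary key in this setup, which is why the merge workflow needs a way to drop them.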
1 parent 71bd17f, commit 00ccbf3
Showing 18 changed files with 380 additions and 357 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
../tritonsrc/_common_test.py |