feat: no-cache attention in PyTorch workflow #3085
Conversation
Force-pushed 049879f to bd371f8 (Compare)
/bot run
PR_Github #498 [ run ] triggered by Bot
PR_Github #498 [ run ] completed with state
/bot run
PR_Github #617 [ run ] triggered by Bot
PR_Github #617 [ run ] completed with state
Reviewed internally, approved.
Just need to add the README and pass CI.
/bot run
PR_Github #622 [ run ] triggered by Bot
PR_Github #622 [ run ] completed with state
Force-pushed 5ab8aba to 78b1917 (Compare)
/bot run
PR_Github #932 [ run ] triggered by Bot
PR_Github #932 [ run ] completed with state
/bot run
PR_Github #977 [ run ] triggered by Bot
PR_Github #977 [ run ] completed with state
/bot run
PR_Github #1036 [ run ] triggered by Bot
PR_Github #1036 [ run ] completed with state
/bot run --stage-list "L40S-1"
PR_Github #1059 [ run ] triggered by Bot
PR_Github #1059 [ run ] completed with state
Force-pushed 6ac8df5 to c5785a6 (Compare)
/bot run --disable-fail-fast
PR_Github #1079 [ run ] triggered by Bot
PR_Github #1079 [ run ] completed with state
/bot run --stage-list "A30-CPP-1"
PR_Github #1151 [ run ] completed with state
/bot run --reuse-test
PR_Github #1164 [ run ] triggered by Bot
PR_Github #1164 [ run ] completed with state
/bot run --disable-fail-fast
PR_Github #1167 [ run ] triggered by Bot
PR_Github #1167 [ run ] completed with state
Signed-off-by: Qixiang Lin <[email protected]>
…model test fix: fix minor bugs after rebase Signed-off-by: Qixiang Lin <[email protected]>
refactor: update max_seq_len documentation and remove max_seq_len for decoder model constructor in PyTorchModelEngine Signed-off-by: Qixiang Lin <[email protected]>
…s and mask type, enhance test_attention_no_cache to support FULL and CAUSAL masks Signed-off-by: Qixiang Lin <[email protected]>
… add type assertion for no cache attention in PyTorchModelEngine Signed-off-by: Qixiang Lin <[email protected]>
…elated classes, update documentation for KV cache handling Signed-off-by: Qixiang Lin <[email protected]>
…handling and remove unused conversion function Signed-off-by: Qixiang Lin <[email protected]>
…ess with useKVCache method and simplify token per block assignment; remove debug code. Signed-off-by: Qixiang Lin <[email protected]>
Simplify no cache attention metadata preparation and streamline related attributes in TrtllmAttentionMetadata. Removed the private method for converting to no cache attention metadata and integrated its logic into the prepare method. Updated the test for BERT sequence classification to reflect these changes and ensure proper handling of attention metadata. Signed-off-by: Qixiang Lin <[email protected]>
…on operations Signed-off-by: Qixiang Lin <[email protected]>
… relevant metadata classes. Updated the attention backend interface to include KVCacheParams and imported TrtllmAttentionMetadata and VanillaAttentionMetadata in model_engine.py for enhanced functionality. Signed-off-by: Qixiang Lin <[email protected]>
Signed-off-by: Qixiang Lin <[email protected]>
Added support for additional attention mask types (BIDIRECTIONAL, BIDIRECTIONALGLM, BLOCKSPARSE) in the MHARunnerFixedParams structure to fix the mapping issue between ContextAttentionMaskType and AttentionMaskType Signed-off-by: Qixiang Lin <[email protected]>
Updated the setAttentionMaskType method to include a switch-case structure for better handling of attention mask types, ensuring proper mapping and error handling for invalid types. Signed-off-by: Qixiang Lin <[email protected]>
Force-pushed 0b9a699 to c211717 (Compare)
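The commits above fold the no-cache preparation directly into the attention metadata's prepare step and replace direct KV-cache member access with a useKVCache-style accessor. Below is a minimal, self-contained sketch of that idea; apart from the names quoted in the commit messages (TrtllmAttentionMetadata, KVCacheParams, is_dummy_attention), the class and fields are hypothetical simplifications, not the actual implementation.

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class NoCacheAttentionMetadata:
    """Hypothetical, simplified stand-in for TrtllmAttentionMetadata.

    When kv_cache_params is None the request runs without a KV cache
    (e.g. encoder-only or reward models), so no cache block offsets or
    cached-token lengths need to be prepared.
    """
    seq_lens: torch.Tensor                     # [num_seqs] prompt lengths
    kv_cache_params: Optional[object] = None   # stands in for KVCacheParams
    is_dummy_attention: bool = False           # simulation-only forward passes
    # Derived fields filled in by prepare()
    num_tokens: int = 0
    cu_seqlens: Optional[torch.Tensor] = None

    @property
    def use_kv_cache(self) -> bool:
        # Mirrors the useKVCache()-style accessor mentioned in the commits.
        return self.kv_cache_params is not None

    def prepare(self) -> None:
        # Cumulative sequence lengths are enough for packed no-cache attention.
        self.num_tokens = int(self.seq_lens.sum().item())
        self.cu_seqlens = torch.nn.functional.pad(
            torch.cumsum(self.seq_lens, dim=0, dtype=torch.int32), (1, 0))
        if self.use_kv_cache:
            # A real implementation would also compute cache block offsets here.
            raise NotImplementedError("KV-cache path omitted in this sketch")


# Example: three sequences of different lengths, no KV cache.
meta = NoCacheAttentionMetadata(seq_lens=torch.tensor([4, 7, 2]))
meta.prepare()
print(meta.num_tokens, meta.cu_seqlens.tolist())  # 13 [0, 4, 11, 13]
```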
/bot reuse-pipeline
PR_Github #1182 [ reuse-pipeline ] triggered by Bot
PR_Github #1182 [ reuse-pipeline ] completed with state
* init trtllm attn no cache
  Signed-off-by: Qixiang Lin <[email protected]>
* fix: fix the seq_len issue and attn metadata prepare for qwen reward model test
  fix: fix minor bugs after rebase
  Signed-off-by: Qixiang Lin <[email protected]>
* refactor: remove unnecessary debug logs and clean up commented code
  refactor: update max_seq_len documentation and remove max_seq_len for decoder model constructor in PyTorchModelEngine
  Signed-off-by: Qixiang Lin <[email protected]>
* refactor: update calculate_ref_result function to accept tensor inputs and mask type, enhance test_attention_no_cache to support FULL and CAUSAL masks
  Signed-off-by: Qixiang Lin <[email protected]>
* refactor: remove unused BERT attention metadata conversion method and add type assertion for no cache attention in PyTorchModelEngine
  Signed-off-by: Qixiang Lin <[email protected]>
* refactor: remove use_kv_cache parameter from attention function and related classes, update documentation for KV cache handling
  Signed-off-by: Qixiang Lin <[email protected]>
* refactor: implement setAttentionMaskType method for better mask type handling and remove unused conversion function
  Signed-off-by: Qixiang Lin <[email protected]>
* refactor: streamline KV cache handling by replacing direct member access with useKVCache method and simplify token per block assignment; remove debug code
  Signed-off-by: Qixiang Lin <[email protected]>
* refactor: Resolve comments for Python code
  Simplify no cache attention metadata preparation and streamline related attributes in TrtllmAttentionMetadata. Removed the private method for converting to no cache attention metadata and integrated its logic into the prepare method. Updated the test for BERT sequence classification to reflect these changes and ensure proper handling of attention metadata.
  Signed-off-by: Qixiang Lin <[email protected]>
* docs: Add is_dummy_attention field to attention metadata for simulation operations
  Signed-off-by: Qixiang Lin <[email protected]>
* refactor: add KVCacheParams to attention backend interface and import relevant metadata classes
  Updated the attention backend interface to include KVCacheParams and imported TrtllmAttentionMetadata and VanillaAttentionMetadata in model_engine.py for enhanced functionality.
  Signed-off-by: Qixiang Lin <[email protected]>
* fix: fix rebase format issue
  Signed-off-by: Qixiang Lin <[email protected]>
* fix: extend attention mask type handling in MHARunnerFixedParams
  Added support for additional attention mask types (BIDIRECTIONAL, BIDIRECTIONALGLM, BLOCKSPARSE) in the MHARunnerFixedParams structure to fix the mapping issue between ContextAttentionMaskType and AttentionMaskType.
  Signed-off-by: Qixiang Lin <[email protected]>
* fix: enhance attention mask type handling in TllmGenFmhaRunnerParams
  Updated the setAttentionMaskType method to include a switch-case structure for better handling of attention mask types, ensuring proper mapping and error handling for invalid types.
  Signed-off-by: Qixiang Lin <[email protected]>
---------
Signed-off-by: Qixiang Lin <[email protected]>
Signed-off-by: sarattha <[email protected]>
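The last two fixes in this squash message concern the mapping between ContextAttentionMaskType and AttentionMaskType in the C++ runner parameters (MHARunnerFixedParams, TllmGenFmhaRunnerParams). As an illustration only, here is a Python analogue of an explicit switch-style mapping with error handling for unsupported types; the real change is in C++, and any enum member not named in the commit messages (e.g. PADDING, CUSTOM_MASK) is an assumption.

```python
from enum import Enum, auto


class ContextAttentionMaskType(Enum):
    # BIDIRECTIONAL, BIDIRECTIONALGLM, BLOCKSPARSE are quoted from the
    # commit messages; the remaining members are illustrative.
    PADDING = auto()
    CAUSAL = auto()
    BIDIRECTIONAL = auto()
    BIDIRECTIONALGLM = auto()
    BLOCKSPARSE = auto()
    CUSTOM_MASK = auto()


class AttentionMaskType(Enum):
    PADDING = auto()
    CAUSAL = auto()
    BIDIRECTIONAL = auto()
    BIDIRECTIONALGLM = auto()
    BLOCKSPARSE = auto()
    CUSTOM_MASK = auto()


def set_attention_mask_type(mask_type: ContextAttentionMaskType) -> AttentionMaskType:
    """Explicit mapping with an error for unsupported types, analogous to
    the setAttentionMaskType switch-case fix described above."""
    mapping = {
        ContextAttentionMaskType.PADDING: AttentionMaskType.PADDING,
        ContextAttentionMaskType.CAUSAL: AttentionMaskType.CAUSAL,
        ContextAttentionMaskType.BIDIRECTIONAL: AttentionMaskType.BIDIRECTIONAL,
        ContextAttentionMaskType.BIDIRECTIONALGLM: AttentionMaskType.BIDIRECTIONALGLM,
        ContextAttentionMaskType.BLOCKSPARSE: AttentionMaskType.BLOCKSPARSE,
    }
    try:
        return mapping[mask_type]
    except KeyError:
        raise ValueError(f"Unsupported attention mask type: {mask_type}") from None
```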
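One of the commits extends test_attention_no_cache to check FULL and CAUSAL masks against a reference result computed from plain tensors. A small PyTorch sketch of such a reference computation follows; calculate_ref_result here is a hypothetical simplification built on scaled_dot_product_attention, not the test's actual code.

```python
import torch
import torch.nn.functional as F


def calculate_ref_result(q: torch.Tensor,
                         k: torch.Tensor,
                         v: torch.Tensor,
                         mask_type: str = "CAUSAL") -> torch.Tensor:
    """Reference attention without any KV cache.

    q, k, v: [batch, num_heads, seq_len, head_dim]
    mask_type: "FULL" (bidirectional) or "CAUSAL".
    """
    is_causal = mask_type == "CAUSAL"
    return F.scaled_dot_product_attention(q, k, v, is_causal=is_causal)


# Example: compare FULL vs CAUSAL outputs on random inputs.
q = torch.randn(2, 8, 16, 64)
k = torch.randn(2, 8, 16, 64)
v = torch.randn(2, 8, 16, 64)
full_out = calculate_ref_result(q, k, v, mask_type="FULL")
causal_out = calculate_ref_result(q, k, v, mask_type="CAUSAL")
print(full_out.shape, torch.allclose(full_out, causal_out))  # same shape, different values
```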