
Commit 882198d

qixiang-99 authored and symphonylyh committed
fix: extend attention mask type handling in MHARunnerFixedParams
Added support for additional attention mask types (BIDIRECTIONAL, BIDIRECTIONALGLM, BLOCKSPARSE) in the MHARunnerFixedParams structure to fix the mapping issue between ContextAttentionMaskType and AttentionMaskType.

Signed-off-by: Qixiang Lin <[email protected]>
1 parent: f4b4a35

1 file changed (+9, -0 lines)


cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fused_multihead_attention_common.h

Lines changed: 9 additions & 0 deletions
@@ -202,6 +202,15 @@ struct MHARunnerFixedParams
         case 2: // tensorrt_llm::kernels::AttentionMaskType::SLIDING_WINDOW_CAUSAL
             attentionMaskType = ContextAttentionMaskType::SLIDING_WINDOW_CAUSAL;
             break;
+        case 3: // tensorrt_llm::kernels::AttentionMaskType::BIDIRECTIONAL
+            attentionMaskType = ContextAttentionMaskType::CAUSAL;
+            break;
+        case 4: // tensorrt_llm::kernels::AttentionMaskType::BIDIRECTIONALGLM
+            attentionMaskType = ContextAttentionMaskType::CAUSAL;
+            break;
+        case 5: // tensorrt_llm::kernels::AttentionMaskType::BLOCKSPARSE
+            attentionMaskType = ContextAttentionMaskType::CAUSAL;
+            break;
         case 6: // tensorrt_llm::kernels::AttentionMaskType::CUSTOM_MASK
             attentionMaskType = ContextAttentionMaskType::CUSTOM_MASK;
             break;
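
For readers without the TensorRT-LLM sources at hand, here is a minimal, self-contained C++ sketch of the enum mapping this hunk establishes. The two enums below are simplified stand-ins, not the real headers: only the enumerator values spelled out in the diff comments are assumed (cases 0 and 1 are not shown in the hunk), and toContextMaskType is a hypothetical free function rather than the actual MHARunnerFixedParams member.

#include <iostream>

// Stand-in for tensorrt_llm::kernels::AttentionMaskType; values taken from the diff comments.
enum class AttentionMaskType : int
{
    SLIDING_WINDOW_CAUSAL = 2,
    BIDIRECTIONAL = 3,
    BIDIRECTIONALGLM = 4,
    BLOCKSPARSE = 5,
    CUSTOM_MASK = 6,
};

// Stand-in for the context FMHA mask type enum.
enum class ContextAttentionMaskType
{
    CAUSAL,
    SLIDING_WINDOW_CAUSAL,
    CUSTOM_MASK,
};

// Mirrors the patched switch: the three newly handled mask types all resolve to the
// plain CAUSAL context mask.
ContextAttentionMaskType toContextMaskType(AttentionMaskType maskType)
{
    switch (static_cast<int>(maskType))
    {
    case 2: // SLIDING_WINDOW_CAUSAL
        return ContextAttentionMaskType::SLIDING_WINDOW_CAUSAL;
    case 3: // BIDIRECTIONAL
    case 4: // BIDIRECTIONALGLM
    case 5: // BLOCKSPARSE
        return ContextAttentionMaskType::CAUSAL;
    case 6: // CUSTOM_MASK
        return ContextAttentionMaskType::CUSTOM_MASK;
    default: // cases 0/1 are outside this hunk; CAUSAL here is a placeholder, not the real default
        return ContextAttentionMaskType::CAUSAL;
    }
}

int main()
{
    // After this commit, BIDIRECTIONALGLM maps to CAUSAL instead of being unhandled.
    std::cout << static_cast<int>(toContextMaskType(AttentionMaskType::BIDIRECTIONALGLM)) << "\n";
    return 0;
}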
