Skip to content

Commit c211717

Browse files
qixiang-99 authored and symphonylyh committed
fix: enhance attention mask type handling in TllmGenFmhaRunnerParams
Updated the setAttentionMaskType method to replace the range check and raw static_cast with a switch-case structure that explicitly maps each ContextAttentionMaskType value to its TrtllmGenAttentionMaskType counterpart and throws an error for any value that cannot be mapped. Signed-off-by: Qixiang Lin <[email protected]>
1 parent 882198d commit c211717

File tree

2 files changed

+21
-2
lines changed

2 files changed

+21
-2
lines changed

Diff for: cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fused_multihead_attention_common.h

+1
Original file line number | Diff line number | Diff line change
@@ -202,6 +202,7 @@ struct MHARunnerFixedParams
202202
case 2: // tensorrt_llm::kernels::AttentionMaskType::SLIDING_WINDOW_CAUSAL
203203
attentionMaskType = ContextAttentionMaskType::SLIDING_WINDOW_CAUSAL;
204204
break;
205+
// NOTE: In the context phase, BIDIRECTIONAL, BIDIRECTIONALGLM and BLOCKSPARSE all use the CAUSAL mask.
205206
case 3: // tensorrt_llm::kernels::AttentionMaskType::BIDIRECTIONAL
206207
attentionMaskType = ContextAttentionMaskType::CAUSAL;
207208
break;

Diff for: cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaRunnerParams.h

+20 -2
Original file line number | Diff line number | Diff line change
@@ -238,8 +238,26 @@ struct TllmGenFmhaRunnerParams
238238
// set the attention mask type
239239
TllmGenFmhaRunnerParams& setAttentionMaskType(std::int8_t maskType)
240240
{
241-
TLLM_CHECK_WITH_INFO(maskType >= 0 && maskType <= 3, "Invalid mask type for TrtllmGenAttentionMaskType");
242-
mMaskType = static_cast<TrtllmGenAttentionMaskType>(maskType);
241+
// maskType carries the integer value of a tensorrt_llm::kernels::ContextAttentionMaskType enumerator
242+
// convert ContextAttentionMaskType to TrtllmGenAttentionMaskType
243+
switch (maskType)
244+
{
245+
case 0: // tensorrt_llm::kernels::ContextAttentionMaskType::PADDING
246+
mMaskType = TrtllmGenAttentionMaskType::Dense;
247+
break;
248+
case 1: // tensorrt_llm::kernels::ContextAttentionMaskType::CAUSAL
249+
mMaskType = TrtllmGenAttentionMaskType::Causal;
250+
break;
251+
case 2: // tensorrt_llm::kernels::ContextAttentionMaskType::SLIDING_WINDOW_CAUSAL
252+
mMaskType = TrtllmGenAttentionMaskType::SlidingWindowCausal;
253+
break;
254+
case 3: // tensorrt_llm::kernels::ContextAttentionMaskType::CUSTOM_MASK
255+
mMaskType = TrtllmGenAttentionMaskType::Custom;
256+
break;
257+
default:
258+
TLLM_THROW("ContextAttentionMaskType %d cannot be mapped to TrtllmGenAttentionMaskType",
259+
static_cast<int>(maskType));
260+
}
243261
return *this;
244262
}
245263
};

0 commit comments

Comments
 (0)