Skip to content

Commit 5bb3a41

Browse files
authored
[JAX] Add the missing 1HSS tests (NVIDIA#1052)
Add the missing 1HSS tests Signed-off-by: Reese Wang <[email protected]>
1 parent d74e65f commit 5bb3a41

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

tests/jax/test_fused_attn.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -295,7 +295,10 @@ def _check_configs(self):
295295
if self.backend == NVTE_Fused_Attn_Backend.NVTE_No_Backend:
296296
pytest.skip("Unsupported inputs combination or device compute capability.")
297297

298-
if self.attn_bias_type == AttnBiasType.POST_SCALE_BIAS:
298+
if (
299+
self.attn_bias_type == AttnBiasType.POST_SCALE_BIAS
300+
and self.bias_shape != BiasShape.BIAS_1HSS
301+
):
299302
if self.attn_mask_type not in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]:
300303
pytest.skip(
301304
"B1SS, BHSS and 11SS bias shapes are only supported for "

0 commit comments

Comments
 (0)