We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent d74e65f commit 5bb3a41Copy full SHA for 5bb3a41
tests/jax/test_fused_attn.py
@@ -295,7 +295,10 @@ def _check_configs(self):
295
if self.backend == NVTE_Fused_Attn_Backend.NVTE_No_Backend:
296
pytest.skip("Unsupported inputs combination or device compute capability.")
297
298
- if self.attn_bias_type == AttnBiasType.POST_SCALE_BIAS:
+ if (
299
+ self.attn_bias_type == AttnBiasType.POST_SCALE_BIAS
300
+ and self.bias_shape != BiasShape.BIAS_1HSS
301
+ ):
302
if self.attn_mask_type not in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]:
303
pytest.skip(
304
"B1SS, BHSS and 11SS bias shapes are only supported for "
0 commit comments