Commit 07d4241

update default sm num (#10586)
1 parent 65d2d3a

File tree

2 files changed: +3 -3


ops/csrc/fp8/deep_gemm/jit_kernels/gemm.py

Lines changed: 1 addition & 1 deletion

@@ -171,7 +171,7 @@ def auto_tuning_with_compilation(m, n, k, num_sms):
     return runtime, num_sms, smem_size
 
 
-def gemm_fp8_fp8_bf16_nt(lhs: Tuple[Tensor, Tensor], rhs: Tuple[Tensor, Tensor], out: Tensor, num_sms=112) -> None:
+def gemm_fp8_fp8_bf16_nt(lhs: Tuple[Tensor, Tensor], rhs: Tuple[Tensor, Tensor], out: Tensor, num_sms=132) -> None:
     """
     Do a normal GEMM with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling.
     LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format.
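
The only substantive change here bumps the default num_sms (the number of streaming multiprocessors the kernel may occupy) from 112 to 132, which matches full Hopper-class parts such as the H800 and the H100 SXM. For context, below is a minimal call sketch. The tuple layouts and per-block scale shapes are inferred from the docstring (1x128 LHS scaling, 128x128 RHS scaling) and the NT naming (LHS non-transposed, RHS transposed); the import path and Paddle's FP8 dtype support are assumptions, not verified against this repository.

# Hypothetical usage sketch for gemm_fp8_fp8_bf16_nt; shapes and scale
# layouts are inferred from the docstring, not checked against the source.
import paddle

from deep_gemm.jit_kernels.gemm import gemm_fp8_fp8_bf16_nt  # module path assumed

m, n, k = 4096, 4096, 7168  # NT layout: lhs is (m, k), rhs is (n, k)

lhs = (
    paddle.randn([m, k]).astype("float8_e4m3fn"),   # FP8 data (dtype support assumed)
    paddle.ones([m, k // 128], dtype="float32"),    # 1x128 per-block scales
)
rhs = (
    paddle.randn([n, k]).astype("float8_e4m3fn"),
    paddle.ones([n // 128, k // 128], dtype="float32"),  # 128x128 per-block scales
)
out = paddle.empty([m, n], dtype="bfloat16")

# num_sms now defaults to 132; pass it explicitly on GPUs with fewer SMs.
gemm_fp8_fp8_bf16_nt(lhs, rhs, out)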

ops/csrc/fp8/deep_gemm/jit_kernels/m_grouped_gemm.py

Lines changed: 2 additions & 2 deletions

@@ -98,7 +98,7 @@ def auto_tuning_with_compilation_grouped_gemm_contiguous(m, n, k, num_groups, nu
 
 
 def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
-    lhs: Tuple[Tensor, Tensor], rhs: Tuple[Tensor, Tensor], out: Tensor, m_indices: Tensor, num_sms=112
+    lhs: Tuple[Tensor, Tensor], rhs: Tuple[Tensor, Tensor], out: Tensor, m_indices: Tensor, num_sms=132
 ) -> None:
     """
     Do a grouped GEMM (contiguous format) with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling.
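
The contiguous grouped variant gets the same default bump. As a layout aid, here is a hedged sketch of the contiguous (MoE-style) format: the rows of all groups are packed along the M axis, and m_indices maps each row to its group. The shapes follow DeepGEMM's documented convention for this function family and, like the import path, are assumptions with respect to this port.

# Hypothetical sketch of the contiguous grouped-GEMM layout; shapes follow
# DeepGEMM's documented convention and are not re-checked against this repo.
import paddle

from deep_gemm.jit_kernels.m_grouped_gemm import (  # module path assumed
    m_grouped_gemm_fp8_fp8_bf16_nt_contiguous,
)

num_groups, m_per_group, n, k = 4, 256, 4096, 7168
m_sum = num_groups * m_per_group  # all groups packed along the M axis

lhs = (
    paddle.randn([m_sum, k]).astype("float8_e4m3fn"),
    paddle.ones([m_sum, k // 128], dtype="float32"),
)
rhs = (
    paddle.randn([num_groups, n, k]).astype("float8_e4m3fn"),  # one weight per group
    paddle.ones([num_groups, n // 128, k // 128], dtype="float32"),
)
out = paddle.empty([m_sum, n], dtype="bfloat16")

# Row i of lhs belongs to group m_indices[i] (e.g. an expert id).
m_indices = paddle.arange(num_groups, dtype="int32").repeat_interleave(m_per_group)

m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs, rhs, out, m_indices)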
@@ -215,7 +215,7 @@ def auto_tuning_with_compilation_grouped_gemm_masked(m, expected_m, n, k, num_gr
 
 
 def m_grouped_gemm_fp8_fp8_bf16_nt_masked(
-    lhs: Tuple[Tensor, Tensor], rhs: Tuple[Tensor, Tensor], out: Tensor, masked_m: Tensor, expected_m: int, num_sms=112
+    lhs: Tuple[Tensor, Tensor], rhs: Tuple[Tensor, Tensor], out: Tensor, masked_m: Tensor, expected_m: int, num_sms=132
 ) -> None:
     """
     Do a grouped GEMM (masked format) with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling.
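
In the masked variant, each group owns a fixed m_max slice and masked_m[g] gives the number of valid rows in group g, with expected_m serving as a tuning hint. The sketch below also shows deriving num_sms from the device rather than relying on the new 132-SM default (132 matches H800/H100 SXM; the H100 PCIe, for instance, has 114 SMs). Paddle's device-properties API and FP8 dtype support are assumptions here.

# Hypothetical sketch of the masked grouped-GEMM format; shapes follow
# DeepGEMM's documented convention and Paddle APIs are assumed, not verified.
import paddle

from deep_gemm.jit_kernels.m_grouped_gemm import (  # module path assumed
    m_grouped_gemm_fp8_fp8_bf16_nt_masked,
)

num_groups, m_max, n, k = 4, 512, 4096, 7168

lhs = (
    paddle.randn([num_groups, m_max, k]).astype("float8_e4m3fn"),
    paddle.ones([num_groups, m_max, k // 128], dtype="float32"),
)
rhs = (
    paddle.randn([num_groups, n, k]).astype("float8_e4m3fn"),
    paddle.ones([num_groups, n // 128, k // 128], dtype="float32"),
)
out = paddle.empty([num_groups, m_max, n], dtype="bfloat16")

masked_m = paddle.full([num_groups], 384, dtype="int32")  # valid rows per group
expected_m = 384  # tuning hint for the expected per-group M

# Match num_sms to the actual device instead of the 132-SM default
# (assumes get_device_properties exposes multi_processor_count).
num_sms = paddle.device.cuda.get_device_properties().multi_processor_count

m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs, rhs, out, masked_m, expected_m, num_sms=num_sms)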
