Skip to content

Commit

Permalink
fix typos
Browse files Browse the repository at this point in the history
  • Loading branch information
zigzagcai committed Apr 26, 2024
1 parent f6bb7e2 commit bffcd97
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions tests/ops/test_selective_scan_var_len.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,19 +82,19 @@ def test_selective_scan_variable_length(is_variable_B, is_variable_C, varBC_grou
u_ref = u.detach().clone().requires_grad_()
delta_ref = delta.detach().clone().requires_grad_()
delta_bias_ref = delta_bias.detach().clone().requires_grad_() if delta_bias is not None else None
start_indexes = torch.arange(0, seqlen, seqlen / seq_num, dtype=torch.int64).cuda()
cu_seqlens = torch.arange(0, seqlen, seqlen / seq_num, dtype=torch.int64).cuda()

out, *rest = selective_scan_fn(
u, delta, A, B, C, D, z=z,
delta_bias=delta_bias, delta_softplus=delta_softplus,
return_last_state=return_last_state, cu_seqlens=start_indexes
return_last_state=return_last_state, cu_seqlens=cu_seqlens
)
if return_last_state:
state = rest[0]
out_ref, *rest = selective_scan_ref(
u_ref, delta_ref, A_ref, B_ref, C_ref, D_ref, z=z_ref,
delta_bias=delta_bias_ref, delta_softplus=delta_softplus,
return_last_state=return_last_state, cu_seqlens=start_indexes
return_last_state=return_last_state, cu_seqlens=cu_seqlens
)
if return_last_state:
state_ref = rest[0]
Expand Down

0 comments on commit bffcd97

Please sign in to comment.