Merge pull request #103 from gaxler/patch-1
BUGFIX: layernorm mean, rstd
tridao authored Jan 15, 2024
2 parents da2626b + 2bca384 commit 86a3a90
1 changed file: mamba_ssm/ops/triton/layernorm.py (2 additions, 2 deletions)
@@ -143,8 +143,8 @@ def _layer_norm_fwd(
         assert residual_out.stride(-1) == 1
     else:
         residual_out = None
-    mean = torch.empty((M,), dtype=torch.float32, device="cuda") if not is_rms_norm else None
-    rstd = torch.empty((M,), dtype=torch.float32, device="cuda")
+    mean = torch.empty((M,), dtype=torch.float32, device=x.device) if not is_rms_norm else None
+    rstd = torch.empty((M,), dtype=torch.float32, device=x.device)
     # Less than 64KB per feature: enqueue fused kernel
     MAX_FUSED_SIZE = 65536 // x.element_size()
     BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
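
A minimal sketch (not part of the commit) of why the hardcoded device matters: the string "cuda" resolves to the current default CUDA device, so if the input tensor x lives on another GPU (say cuda:1), the mean and rstd buffers would previously be allocated on a different device than x. The helper name alloc_stats below is hypothetical, used only to isolate the fixed allocation.

import torch

def alloc_stats(x, M, is_rms_norm=False):
    # After the fix: allocate the per-row statistics on the same device as x.
    # Before the fix, device="cuda" would pick the current default GPU instead.
    mean = torch.empty((M,), dtype=torch.float32, device=x.device) if not is_rms_norm else None
    rstd = torch.empty((M,), dtype=torch.float32, device=x.device)
    return mean, rstd

if torch.cuda.device_count() > 1:
    x = torch.randn(4, 8, device="cuda:1")
    mean, rstd = alloc_stats(x, x.shape[0])
    # With device=x.device both buffers live on cuda:1 alongside x;
    # with the old device="cuda" they would land on the default device (cuda:0).
    assert rstd.device == x.device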

