Skip to content

Commit 2bca384

Browse files
authored
BUGFIX: layernorm mean, rstd
The mean and rstd tensors were allocated on the default "cuda" device instead of on x.device, which breaks when the input lives on a non-default CUDA device; allocate them with device=x.device.
1 parent da2626b commit 2bca384

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

mamba_ssm/ops/triton/layernorm.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -143,8 +143,8 @@ def _layer_norm_fwd(
143143
assert residual_out.stride(-1) == 1
144144
else:
145145
residual_out = None
146-
mean = torch.empty((M,), dtype=torch.float32, device="cuda") if not is_rms_norm else None
147-
rstd = torch.empty((M,), dtype=torch.float32, device="cuda")
146+
mean = torch.empty((M,), dtype=torch.float32, device=x.device) if not is_rms_norm else None
147+
rstd = torch.empty((M,), dtype=torch.float32, device=x.device)
148148
# Less than 64KB per feature: enqueue fused kernel
149149
MAX_FUSED_SIZE = 65536 // x.element_size()
150150
BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))

0 commit comments

Comments (0)