diff --git a/mamba_ssm/ops/triton/layernorm.py b/mamba_ssm/ops/triton/layernorm.py index 1babca4c..ba33ce1e 100644 --- a/mamba_ssm/ops/triton/layernorm.py +++ b/mamba_ssm/ops/triton/layernorm.py @@ -143,8 +143,8 @@ def _layer_norm_fwd( assert residual_out.stride(-1) == 1 else: residual_out = None - mean = torch.empty((M,), dtype=torch.float32, device="cuda") if not is_rms_norm else None - rstd = torch.empty((M,), dtype=torch.float32, device="cuda") + mean = torch.empty((M,), dtype=torch.float32, device=x.device) if not is_rms_norm else None + rstd = torch.empty((M,), dtype=torch.float32, device=x.device) # Less than 64KB per feature: enqueue fused kernel MAX_FUSED_SIZE = 65536 // x.element_size() BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))