diff --git a/candle-metal-kernels/src/unary.metal b/candle-metal-kernels/src/unary.metal index a82bfdbdd6..e3a18cfe91 100644 --- a/candle-metal-kernels/src/unary.metal +++ b/candle-metal-kernels/src/unary.metal @@ -56,7 +56,7 @@ template METAL_FUNC T gelu(T x) { T x_cube = x_sq * x; T alpha = x + static_cast(0.044715) * x_cube; T beta = (static_cast(M_2_SQRTPI_F * M_SQRT1_2_F) * alpha); - return static_cast(0.5) * x * (static_cast(1.0) + T(tanh(beta))); + return static_cast(0.5) * x * (static_cast(1.0) + T(precise::tanh(beta))); } template METAL_FUNC T relu(T in){ if (in < 0) { @@ -154,7 +154,6 @@ UNARY_OP(floor) UNARY_OP(round) UNARY_OP(gelu_erf) UNARY_OP(erf) -UNARY_OP(tanh) UNARY_OP(recip) UNARY_OP(relu) UNARY_OP(sign) @@ -164,6 +163,11 @@ UNARY(id, half, copy_f16, copy_f16_strided) UNARY(id, uint8_t, copy_u8, copy_u8_strided) UNARY(id, uint32_t, copy_u32, copy_u32_strided) +// tanh may create NaN on large values, e.g. 45, rather than outputting 1. +// This has been an issue for the encodec example. +UNARY(precise::tanh, float, tanh_f32, tanh_f32_strided); +UNARY(precise::tanh, half, tanh_f16, tanh_f16_strided); + #if __METAL_VERSION__ >= 220 UNARY(id, int64_t, copy_i64, copy_i64_strided) COPY2D(copy2d_i64, int64_t) @@ -185,7 +189,6 @@ BFLOAT_UNARY_OP(floor) BFLOAT_UNARY_OP(round) BFLOAT_UNARY_OP(gelu_erf) BFLOAT_UNARY_OP(erf) -BFLOAT_UNARY_OP(tanh) BFLOAT_UNARY_OP(recip) BFLOAT_UNARY_OP(relu) BFLOAT_UNARY_OP(sign) @@ -193,5 +196,7 @@ BFLOAT_UNARY_OP(sigmoid) UNARY(id, bfloat, copy_bf16, copy_bf16_strided) +UNARY(precise::tanh, bfloat, tanh_bf16, tanh_bf16_strided); + COPY2D(copy2d_bf16, bfloat) #endif