Skip to content

Commit

Permalink
Fix for metal tanh. (huggingface#2475)
Browse files Browse the repository at this point in the history
  • Loading branch information
LaurentMazare authored Sep 13, 2024
1 parent b60faeb commit c09afc2
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions candle-metal-kernels/src/unary.metal
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ template <typename T> METAL_FUNC T gelu(T x) {
T x_cube = x_sq * x;
T alpha = x + static_cast<T>(0.044715) * x_cube;
T beta = (static_cast<T>(M_2_SQRTPI_F * M_SQRT1_2_F) * alpha);
return static_cast<T>(0.5) * x * (static_cast<T>(1.0) + T(tanh(beta)));
return static_cast<T>(0.5) * x * (static_cast<T>(1.0) + T(precise::tanh(beta)));
}
template <typename T> METAL_FUNC T relu(T in){
if (in < 0) {
Expand Down Expand Up @@ -154,7 +154,6 @@ UNARY_OP(floor)
UNARY_OP(round)
UNARY_OP(gelu_erf)
UNARY_OP(erf)
UNARY_OP(tanh)
UNARY_OP(recip)
UNARY_OP(relu)
UNARY_OP(sign)
Expand All @@ -164,6 +163,11 @@ UNARY(id, half, copy_f16, copy_f16_strided)
UNARY(id, uint8_t, copy_u8, copy_u8_strided)
UNARY(id, uint32_t, copy_u32, copy_u32_strided)

// tanh may create NaN on large values, e.g. 45 rather than outputing 1.
// This has been an issue for the encodec example.
UNARY(precise::tanh, float, tanh_f32, tanh_f32_strided);
UNARY(precise::tanh, half, tanh_f16, tanh_f16_strided);

#if __METAL_VERSION__ >= 220
UNARY(id, int64_t, copy_i64, copy_i64_strided)
COPY2D(copy2d_i64, int64_t)
Expand All @@ -185,13 +189,14 @@ BFLOAT_UNARY_OP(floor)
BFLOAT_UNARY_OP(round)
BFLOAT_UNARY_OP(gelu_erf)
BFLOAT_UNARY_OP(erf)
BFLOAT_UNARY_OP(tanh)
BFLOAT_UNARY_OP(recip)
BFLOAT_UNARY_OP(relu)
BFLOAT_UNARY_OP(sign)
BFLOAT_UNARY_OP(sigmoid)

UNARY(id, bfloat, copy_bf16, copy_bf16_strided)

UNARY(precise::tanh, bfloat, tanh_bf16, tanh_bf16_strided);

COPY2D(copy2d_bf16, bfloat)
#endif

0 comments on commit c09afc2

Please sign in to comment.