|
82 | 82 | @device_override function Base.tanh(x::Float64); return ccall("extern __nv_tanh", llvmcall, Cdouble, (Cdouble,), x); end  # device tanh for Float64: forwards to the external `__nv_tanh` symbol
|
83 | 83 | @device_override function Base.tanh(x::Float32); return ccall("extern __nv_tanhf", llvmcall, Cfloat, (Cfloat,), x); end  # device tanh for Float32: forwards to the external `__nv_tanhf` symbol
|
84 | 84 |
|
85 |
| -# TODO: enable once PTX > 7.0 is supported |
86 |
| -# @device_override Base.tanh(x::Float16) = @asmcall("tanh.approx.f16 \$0, \$1", "=h,h", Float16, Tuple{Float16}, x) |
87 |
| - |
| 85 | +@device_override function Base.tanh(x::Float16)  # device tanh for Float16: PTX approximation when available, Float32 round-trip otherwise |
| 86 | + if compute_capability() >= sv"7.5"  # the inline-asm path is only taken on compute capability 7.5 or newer |
| 87 | + @asmcall("tanh.approx.f16 \$0, \$1;", "=h,h", Float16, Tuple{Float16}, x)  # "h" = 16-bit register constraint, matching the .f16 operands ("r" would request 32-bit integer registers) |
| 88 | + else |
| 89 | + Float16(tanh(Float32(x)))  # widen to Float32, evaluate tanh there, narrow the result back to Float16 |
| 90 | + end |
| 91 | +end |
88 | 92 |
|
89 | 93 | ## inverse hyperbolic
|
90 | 94 |
|
|
197 | 201 |
|
198 | 202 | return r
|
199 | 203 | end
|
200 |
| -@device_override FastMath.exp2_fast(x::Union{Float32, Float64}) = exp2(x) |
| 204 | +@device_override function FastMath.exp2_fast(x::Float64); return exp2(x); end  # Float64 has no separate fast path here: defers to the regular exp2 |
| 205 | +@device_override FastMath.exp2_fast(x::Float32) =  # fast base-2 exponential for Float32 via the PTX `ex2.approx.f32` instruction |
| 206 | + @asmcall("ex2.approx.f32 \$0, \$1;", "=f,f", Float32, Tuple{Float32}, x)  # "f" = 32-bit float register constraint ("r" is the 32-bit integer register class) |
| 207 | +@device_override function FastMath.exp2_fast(x::Float16)  # fast base-2 exponential for Float16 |
| 208 | + if !(compute_capability() >= sv"7.5")  # below compute capability 7.5: no approximate intrinsic is used |
| 209 | + exp2(x)  # defer to the regular exp2 for this argument type |
| 210 | + else |
| 211 | + ccall("llvm.nvvm.ex2.approx.f16", llvmcall, Float16, (Float16,), x)  # NVVM approximate ex2 intrinsic operating directly on Float16 |
| 212 | + end |
| 213 | +end |
201 | 214 |
|
202 | 215 | @device_override function Base.exp10(x::Float64); return ccall("extern __nv_exp10", llvmcall, Cdouble, (Cdouble,), x); end  # device exp10 for Float64: forwards to the external `__nv_exp10` symbol
|
203 | 216 | @device_override function Base.exp10(x::Float32); return ccall("extern __nv_exp10f", llvmcall, Cfloat, (Cfloat,), x); end  # device exp10 for Float32: forwards to the external `__nv_exp10f` symbol
|
|
0 commit comments