Skip to content

Commit f89e1ab

Browse files
authored
Add some more fastmath functions (#2047)
1 parent 29235f4 commit f89e1ab

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

src/device/intrinsics/math.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ end
131131

132132
# Device override of `Base.exp2` for `Float64`: forwards `x` to the external
# `__nv_exp2` symbol via an `llvmcall` ccall (one `Cdouble` in, `Cdouble` out).
@device_override function Base.exp2(x::Float64)
    return ccall("extern __nv_exp2", llvmcall, Cdouble, (Cdouble,), x)
end
133133
# Device override of `Base.exp2` for `Float32`: forwards `x` to the external
# `__nv_exp2f` symbol via an `llvmcall` ccall (one `Cfloat` in, `Cfloat` out).
@device_override function Base.exp2(x::Float32)
    return ccall("extern __nv_exp2f", llvmcall, Cfloat, (Cfloat,), x)
end
134+
# Device override of `FastMath.exp2_fast` for `Float32`/`Float64`: simply calls
# `exp2(x)`, so no separate fast-math implementation is used here.
# NOTE(review): presumably this resolves to the device `Base.exp2` overrides — confirm.
@device_override function FastMath.exp2_fast(x::Union{Float32, Float64})
    return exp2(x)
end
134135
# TODO: enable once PTX > 7.0 is supported
135136
# @device_override Base.exp2(x::Float16) = @asmcall("ex2.approx.f16 \$0, \$1", "=h,h", Float16, Tuple{Float16}, x)
136137

@@ -221,6 +222,7 @@ end
221222

222223
# Device override of `Base.sqrt` for `Float64`: forwards `x` to the external
# `__nv_sqrt` symbol via an `llvmcall` ccall (one `Cdouble` in, `Cdouble` out).
@device_override function Base.sqrt(x::Float64)
    return ccall("extern __nv_sqrt", llvmcall, Cdouble, (Cdouble,), x)
end
223224
# Device override of `Base.sqrt` for `Float32`: forwards `x` to the external
# `__nv_sqrtf` symbol via an `llvmcall` ccall (one `Cfloat` in, `Cfloat` out).
@device_override function Base.sqrt(x::Float32)
    return ccall("extern __nv_sqrtf", llvmcall, Cfloat, (Cfloat,), x)
end
225+
# Device override of `FastMath.sqrt_fast` for `Float32`/`Float64`: simply calls
# `sqrt(x)`, so no separate fast-math implementation is used here.
# NOTE(review): presumably this resolves to the device `Base.sqrt` overrides — confirm.
@device_override function FastMath.sqrt_fast(x::Union{Float32, Float64})
    return sqrt(x)
end
224226

225227
# Device function `rsqrt` for `Float64`: forwards `x` to the external
# `__nv_rsqrt` symbol via an `llvmcall` ccall (one `Cdouble` in, `Cdouble` out).
@device_function function rsqrt(x::Float64)
    return ccall("extern __nv_rsqrt", llvmcall, Cdouble, (Cdouble,), x)
end
226228
# Device function `rsqrt` for `Float32`: forwards `x` to the external
# `__nv_rsqrtf` symbol via an `llvmcall` ccall (one `Cfloat` in, `Cfloat` out).
@device_function function rsqrt(x::Float32)
    return ccall("extern __nv_rsqrtf", llvmcall, Cfloat, (Cfloat,), x)
end
@@ -306,6 +308,8 @@ end
306308

307309
# Device override of `FastMath.div_fast` for two `Float32` operands: forwards
# `x` and `y` to the external `__nv_fast_fdividef` symbol via an `llvmcall`
# ccall (two `Cfloat` in, `Cfloat` out).
@device_override function FastMath.div_fast(x::Float32, y::Float32)
    return ccall("extern __nv_fast_fdividef", llvmcall, Cfloat, (Cfloat, Cfloat), x, y)
end
308310

311+
# Device override of `Base.inv` for `Float32`: forwards `x` to the external
# `__nv_frcp_rn` symbol via an `llvmcall` ccall (one `Cfloat` in, `Cfloat` out).
# NOTE(review): the `_rn` suffix presumably denotes round-to-nearest — confirm
# against the libdevice documentation.
@device_override function Base.inv(x::Float32)
    return ccall("extern __nv_frcp_rn", llvmcall, Cfloat, (Cfloat,), x)
end
312+
# Device override of `FastMath.inv_fast` for `Float32`/`Float64`: evaluates
# `one(x) / x` under `@fastmath`, so the division uses the fast-math variant.
# `one(x)` keeps the numerator in the same floating-point type as `x`.
@device_override function FastMath.inv_fast(x::Union{Float32, Float64})
    return @fastmath one(x) / x
end
309313

310314
## distributions
311315

0 commit comments

Comments
 (0)