Skip to content

Commit 27cb1ad

Browse files
maleadt authored and kshyatt committed
Try enabling some more intrinsics.
1 parent 6ab4409 commit 27cb1ad

File tree

1 file changed

+17
-4
lines changed

1 file changed

+17
-4
lines changed

src/device/intrinsics/math.jl

Lines changed: 17 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -82,9 +82,13 @@ end
8282
# Device override of `Base.tanh` for Float64: forwards the argument to the
# external symbol `__nv_tanh` through an `llvmcall`-convention `ccall`.
@device_override function Base.tanh(x::Float64)
    return ccall("extern __nv_tanh", llvmcall, Cdouble, (Cdouble,), x)
end
8383
# Device override of `Base.tanh` for Float32: forwards the argument to the
# external symbol `__nv_tanhf` through an `llvmcall`-convention `ccall`.
@device_override function Base.tanh(x::Float32)
    return ccall("extern __nv_tanhf", llvmcall, Cfloat, (Cfloat,), x)
end
8484

85-
# TODO: enable once PTX > 7.0 is supported
86-
# @device_override Base.tanh(x::Float16) = @asmcall("tanh.approx.f16 \$0, \$1", "=h,h", Float16, Tuple{Float16}, x)
87-
85+
# Device override of `Base.tanh` for Float16.
#
# When the compute capability is at least 7.5, emit the PTX instruction
# `tanh.approx.f16` via inline assembly; otherwise compute the result in
# Float32 and convert it back to Float16.
#
# NOTE(review): `tanh.approx.f16` also needs a sufficiently recent PTX ISA
# version (7.0 according to the PTX documentation) — confirm the toolchain
# targets it, since only the compute capability is checked here.
@device_override function Base.tanh(x::Float16)
    if compute_capability() >= sv"7.5"
        # The operands of an `.f16` instruction are 16 bits wide, so they must be
        # bound to 16-bit registers ("h"), not 32-bit ones ("r").
        return @asmcall("tanh.approx.f16 \$0, \$1;", "=h,h", Float16, Tuple{Float16}, x)
    else
        # No half-precision instruction is used here: widen, evaluate, narrow.
        return Float16(tanh(Float32(x)))
    end
end
8892

8993
## inverse hyperbolic
9094

@@ -197,7 +201,16 @@ end
197201

198202
return r
199203
end
200-
# Pre-change definition (shown on the removed side of this diff): the fast-math
# `exp2_fast` for Float32 and Float64 simply forwards to the regular `exp2`.
@device_override FastMath.exp2_fast(x::Union{Float32, Float64}) = exp2(x)
204+
# Device override of `FastMath.exp2_fast` for Float64: no approximate variant
# is used here, the call is forwarded to the regular `exp2`.
@device_override function FastMath.exp2_fast(x::Float64)
    return exp2(x)
end
205+
# Device override of `FastMath.exp2_fast` for Float32: emits the PTX
# instruction `ex2.approx.f32` via inline assembly.
@device_override function FastMath.exp2_fast(x::Float32)
    # "f" is the inline-assembly constraint for 32-bit floating-point registers;
    # "r" denotes 32-bit integer registers and is not the documented choice for
    # an `.f32` operand.
    return @asmcall("ex2.approx.f32 \$0, \$1;", "=f,f", Float32, Tuple{Float32}, x)
end
207+
# Device override of `FastMath.exp2_fast` for Float16.
#
# With a compute capability of at least 7.5 the LLVM intrinsic
# `llvm.nvvm.ex2.approx.f16` is called; otherwise the regular `exp2` is used.
@device_override function FastMath.exp2_fast(x::Float16)
    use_intrinsic = compute_capability() >= sv"7.5"
    return use_intrinsic ?
        ccall("llvm.nvvm.ex2.approx.f16", llvmcall, Float16, (Float16,), x) :
        exp2(x)
end
201214

202215
# Device override of `Base.exp10` for Float64: forwards the argument to the
# external symbol `__nv_exp10` through an `llvmcall`-convention `ccall`.
@device_override function Base.exp10(x::Float64)
    return ccall("extern __nv_exp10", llvmcall, Cdouble, (Cdouble,), x)
end
203216
# Device override of `Base.exp10` for Float32: forwards the argument to the
# external symbol `__nv_exp10f` through an `llvmcall`-convention `ccall`.
@device_override function Base.exp10(x::Float32)
    return ccall("extern __nv_exp10f", llvmcall, Cfloat, (Cfloat,), x)
end

0 commit comments

Comments
 (0)