Skip to content

Commit 1c15c77

Browse files
committed
Wrap and test some more Float16 intrinsics
1 parent f62af73 commit 1c15c77

File tree

2 files changed

+202
-11
lines changed

2 files changed

+202
-11
lines changed

src/device/intrinsics/math.jl

+177-5
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
using Base: FastMath
44

5-
65
## helpers
76

87
within(lower, upper) = (val) -> lower <= val <= upper
@@ -103,17 +102,98 @@ end
103102

104103
@device_override Base.log(x::Float64) = ccall("extern __nv_log", llvmcall, Cdouble, (Cdouble,), x)
105104
@device_override Base.log(x::Float32) = ccall("extern __nv_logf", llvmcall, Cfloat, (Cfloat,), x)
105+
@device_override function Base.log(x::Float16)
106+
log_x = @asmcall("""{.reg.b32 f, C;
107+
.reg.b16 r,h;
108+
mov.b16 h,\$1;
109+
cvt.f32.f16 f,h;
110+
lg2.approx.ftz.f32 f,f;
111+
mov.b32 C, 0x3f317218U;
112+
mul.f32 f,f,C;
113+
cvt.rn.f16.f32 r,f;
114+
.reg.b16 spc, ulp, p;
115+
mov.b16 spc, 0X160DU;
116+
mov.b16 ulp, 0x9C00U;
117+
set.eq.f16.f16 p, h, spc;
118+
fma.rn.f16 r,p,ulp,r;
119+
mov.b16 spc, 0X3BFEU;
120+
mov.b16 ulp, 0x8010U;
121+
set.eq.f16.f16 p, h, spc;
122+
fma.rn.f16 r,p,ulp,r;
123+
mov.b16 spc, 0X3C0BU;
124+
mov.b16 ulp, 0x8080U;
125+
set.eq.f16.f16 p, h, spc;
126+
fma.rn.f16 r,p,ulp,r;
127+
mov.b16 spc, 0X6051U;
128+
mov.b16 ulp, 0x1C00U;
129+
set.eq.f16.f16 p, h, spc;
130+
fma.rn.f16 r,p,ulp,r;
131+
mov.b16 \$0,r;
132+
}""", "=h,h", Float16, Tuple{Float16}, x)
133+
return log_x
134+
end
135+
106136
@device_override FastMath.log_fast(x::Float32) = ccall("extern __nv_fast_logf", llvmcall, Cfloat, (Cfloat,), x)
107137

108138
@device_override Base.log10(x::Float64) = ccall("extern __nv_log10", llvmcall, Cdouble, (Cdouble,), x)
109139
@device_override Base.log10(x::Float32) = ccall("extern __nv_log10f", llvmcall, Cfloat, (Cfloat,), x)
140+
@device_override function Base.log10(x::Float16)
141+
log_x = @asmcall("""{.reg.b16 h, r;
142+
.reg.b32 f, C;
143+
mov.b16 h, \$1;
144+
cvt.f32.f16 f, h;
145+
lg2.approx.ftz.f32 f, f;
146+
mov.b32 C, 0x3E9A209BU;
147+
mul.f32 f,f,C;
148+
cvt.rn.f16.f32 r, f;
149+
.reg.b16 spc, ulp, p;
150+
mov.b16 spc, 0x338FU;
151+
mov.b16 ulp, 0x1000U;
152+
set.eq.f16.f16 p, h, spc;
153+
fma.rn.f16 r,p,ulp,r;
154+
mov.b16 spc, 0x33F8U;
155+
mov.b16 ulp, 0x9000U;
156+
set.eq.f16.f16 p, h, spc;
157+
fma.rn.f16 r,p,ulp,r;
158+
mov.b16 spc, 0x57E1U;
159+
mov.b16 ulp, 0x9800U;
160+
set.eq.f16.f16 p, h, spc;
161+
fma.rn.f16 r,p,ulp,r;
162+
mov.b16 spc, 0x719DU;
163+
mov.b16 ulp, 0x9C00U;
164+
set.eq.f16.f16 p, h, spc;
165+
fma.rn.f16 r,p,ulp,r;
166+
mov.b16 \$0, r;
167+
}""", "=h,h", Float16, Tuple{Float16}, x)
168+
return log_x
169+
end
110170
@device_override FastMath.log10_fast(x::Float32) = ccall("extern __nv_fast_log10f", llvmcall, Cfloat, (Cfloat,), x)
111171

112172
@device_override Base.log1p(x::Float64) = ccall("extern __nv_log1p", llvmcall, Cdouble, (Cdouble,), x)
113173
@device_override Base.log1p(x::Float32) = ccall("extern __nv_log1pf", llvmcall, Cfloat, (Cfloat,), x)
114174

115175
@device_override Base.log2(x::Float64) = ccall("extern __nv_log2", llvmcall, Cdouble, (Cdouble,), x)
116176
@device_override Base.log2(x::Float32) = ccall("extern __nv_log2f", llvmcall, Cfloat, (Cfloat,), x)
177+
@device_override function Base.log2(x::Float16)
178+
log_x = @asmcall("""{.reg.b16 h, r;
179+
.reg.b32 f;
180+
mov.b16 h, \$1;
181+
cvt.f32.f16 f, h;
182+
lg2.approx.ftz.f32 f, f;
183+
cvt.rn.f16.f32 r, f;
184+
.reg.b16 spc, ulp, p;
185+
mov.b16 spc, 0xA2E2U;
186+
mov.b16 ulp, 0x8080U;
187+
set.eq.f16.f16 p, r, spc;
188+
fma.rn.f16 r,p,ulp,r;
189+
mov.b16 spc, 0xBF46U;
190+
mov.b16 ulp, 0x9400U;
191+
set.eq.f16.f16 p, r, spc;
192+
fma.rn.f16 r,p,ulp,r;
193+
mov.b16 \$0, r;
194+
}""", "=h,h", Float16, Tuple{Float16}, x)
195+
return log_x
196+
end
117197
@device_override FastMath.log2_fast(x::Float32) = ccall("extern __nv_fast_log2f", llvmcall, Cfloat, (Cfloat,), x)
118198

119199
@device_function logb(x::Float64) = ccall("extern __nv_logb", llvmcall, Cdouble, (Cdouble,), x)
@@ -127,16 +207,95 @@ end
127207

128208
@device_override Base.exp(x::Float64) = ccall("extern __nv_exp", llvmcall, Cdouble, (Cdouble,), x)
129209
@device_override Base.exp(x::Float32) = ccall("extern __nv_expf", llvmcall, Cfloat, (Cfloat,), x)
210+
@device_override function Base.exp(x::Float16)
211+
exp_x = @asmcall("""{
212+
.reg.b32 f, C, nZ;
213+
.reg.b16 h,r;
214+
mov.b16 h,\$1;
215+
cvt.f32.f16 f,h;
216+
mov.b32 C, 0x3fb8aa3bU;
217+
mov.b32 nZ, 0x80000000U;
218+
fma.rn.f32 f,f,C,nZ;
219+
ex2.approx.ftz.f32 f,f;
220+
cvt.rn.f16.f32 r,f;
221+
.reg.b16 spc, ulp, p;
222+
mov.b16 spc,0X1F79U;
223+
mov.b16 ulp,0x9400U;
224+
set.eq.f16.f16 p, h, spc;
225+
fma.rn.f16 r,p,ulp,r;
226+
mov.b16 spc,0X25CFU;
227+
mov.b16 ulp,0x9400U;
228+
set.eq.f16.f16 p, h, spc;
229+
fma.rn.f16 r,p,ulp,r;
230+
mov.b16 spc,0XC13BU;
231+
mov.b16 ulp,0x0400U;
232+
set.eq.f16.f16 p, h, spc;
233+
fma.rn.f16 r,p,ulp,r;
234+
mov.b16 spc,0XC1EFU;
235+
mov.b16 ulp,0x0200U;
236+
set.eq.f16.f16 p, h, spc;
237+
fma.rn.f16 r,p,ulp,r;
238+
mov.b16 \$0,r;
239+
}""", "=h,h", Float16, Tuple{Float16}, x)
240+
return exp_x
241+
end
130242
@device_override FastMath.exp_fast(x::Float32) = ccall("extern __nv_fast_expf", llvmcall, Cfloat, (Cfloat,), x)
131243

132244
@device_override Base.exp2(x::Float64) = ccall("extern __nv_exp2", llvmcall, Cdouble, (Cdouble,), x)
133245
@device_override Base.exp2(x::Float32) = ccall("extern __nv_exp2f", llvmcall, Cfloat, (Cfloat,), x)
246+
@device_override function Base.exp2(x::Float16)
247+
exp_x = @asmcall("""{.reg.b32 f, ULP;
248+
.reg.b16 r;
249+
mov.b16 r,\$1;
250+
cvt.f32.f16 f,r;
251+
ex2.approx.ftz.f32 f,f;
252+
mov.b32 ULP, 0x33800000U;
253+
fma.rn.f32 f,f,ULP,f;
254+
cvt.rn.f16.f32 r,f;
255+
mov.b16 \$0,r;
256+
}""", "=h,h", Float16, Tuple{Float16}, x)
257+
return exp_x
258+
end
134259
@device_override FastMath.exp2_fast(x::Union{Float32, Float64}) = exp2(x)
135-
# TODO: enable once PTX > 7.0 is supported
136-
# @device_override Base.exp2(x::Float16) = @asmcall("ex2.approx.f16 \$0, \$1", "=h,h", Float16, Tuple{Float16}, x)
137260

138261
@device_override Base.exp10(x::Float64) = ccall("extern __nv_exp10", llvmcall, Cdouble, (Cdouble,), x)
139262
@device_override Base.exp10(x::Float32) = ccall("extern __nv_exp10f", llvmcall, Cfloat, (Cfloat,), x)
263+
@device_override function Base.exp10(x::Float16)
264+
265+
exp_x = @asmcall("""{.reg.b16 h,r;
266+
.reg.b32 f, C, nZ;
267+
mov.b16 h, \$1;
268+
cvt.f32.f16 f, h;
269+
mov.b32 C, 0x40549A78U;
270+
mov.b32 nZ, 0x80000000U;
271+
fma.rn.f32 f,f,C,nZ;
272+
ex2.approx.ftz.f32 f, f;
273+
cvt.rn.f16.f32 r, f;
274+
.reg.b16 spc, ulp, p;
275+
mov.b16 spc,0x34DEU;
276+
mov.b16 ulp,0x9800U;
277+
set.eq.f16.f16 p, h, spc;
278+
fma.rn.f16 r,p,ulp,r;
279+
mov.b16 spc,0x9766U;
280+
mov.b16 ulp,0x9000U;
281+
set.eq.f16.f16 p, h, spc;
282+
fma.rn.f16 r,p,ulp,r;
283+
mov.b16 spc,0x9972U;
284+
mov.b16 ulp,0x1000U;
285+
set.eq.f16.f16 p, h, spc;
286+
fma.rn.f16 r,p,ulp,r;
287+
mov.b16 spc,0xA5C4U;
288+
mov.b16 ulp,0x1000U;
289+
set.eq.f16.f16 p, h, spc;
290+
fma.rn.f16 r,p,ulp,r;
291+
mov.b16 spc,0xBF0AU;
292+
mov.b16 ulp,0x8100U;
293+
set.eq.f16.f16 p, h, spc;
294+
fma.rn.f16 r,p,ulp,r;
295+
mov.b16 \$0, r;
296+
}""", "=h,h", Float16, Tuple{Float16}, x)
297+
return exp_x
298+
end
140299
@device_override FastMath.exp10_fast(x::Float32) = ccall("extern __nv_fast_exp10f", llvmcall, Cfloat, (Cfloat,), x)
141300

142301
@device_override Base.expm1(x::Float64) = ccall("extern __nv_expm1", llvmcall, Cdouble, (Cdouble,), x)
@@ -204,6 +363,13 @@ end
204363

205364
@device_override Base.isnan(x::Float64) = (ccall("extern __nv_isnand", llvmcall, Int32, (Cdouble,), x)) != 0
206365
@device_override Base.isnan(x::Float32) = (ccall("extern __nv_isnanf", llvmcall, Int32, (Cfloat,), x)) != 0
366+
@device_override function Base.isnan(x::Float16)
367+
if compute_capability() >= sv"8.0"
368+
return (ccall("extern __nv_hisnan", llvmcall, Int32, (Float16,), x)) != 0
369+
else
370+
return isnan(Float32(x))
371+
end
372+
end
207373

208374
@device_function nearbyint(x::Float64) = ccall("extern __nv_nearbyint", llvmcall, Cdouble, (Cdouble,), x)
209375
@device_function nearbyint(x::Float32) = ccall("extern __nv_nearbyintf", llvmcall, Cfloat, (Cfloat,), x)
@@ -223,14 +389,20 @@ end
223389
@device_override Base.abs(x::Int32) = ccall("extern __nv_abs", llvmcall, Int32, (Int32,), x)
224390
@device_override Base.abs(f::Float64) = ccall("extern __nv_fabs", llvmcall, Cdouble, (Cdouble,), f)
225391
@device_override Base.abs(f::Float32) = ccall("extern __nv_fabsf", llvmcall, Cfloat, (Cfloat,), f)
226-
# TODO: enable once PTX > 7.0 is supported
227-
# @device_override Base.abs(x::Float16) = @asmcall("abs.f16 \$0, \$1", "=h,h", Float16, Tuple{Float16}, x)
392+
@device_override Base.abs(f::Float16) = Float16(abs(Float32(f)))
228393
@device_override Base.abs(x::Int64) = ccall("extern __nv_llabs", llvmcall, Int64, (Int64,), x)
229394

230395
## roots and powers
231396

232397
@device_override Base.sqrt(x::Float64) = ccall("extern __nv_sqrt", llvmcall, Cdouble, (Cdouble,), x)
233398
@device_override Base.sqrt(x::Float32) = ccall("extern __nv_sqrtf", llvmcall, Cfloat, (Cfloat,), x)
399+
@device_override function Base.sqrt(x::Float16)
400+
if compute_capability() >= sv"8.0"
401+
ccall("extern __nv_hsqrt", llvmcall, Float16, (Float16,), x)
402+
else
403+
return Float16(sqrt(Float32(x)))
404+
end
405+
end
234406
@device_override FastMath.sqrt_fast(x::Union{Float32, Float64}) = sqrt(x)
235407

236408
@device_function rsqrt(x::Float64) = ccall("extern __nv_rsqrt", llvmcall, Cdouble, (Cdouble,), x)

test/core/device/intrinsics/math.jl

+25-6
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@ using SpecialFunctions
22

33
@testset "math" begin
44
@testset "log10" begin
5-
@test testf(a->log10.(a), Float32[100])
5+
for T in (Float32, Float64)
6+
@test testf(a->log10.(a), T[100])
7+
end
68
end
79

810
@testset "pow" begin
@@ -12,28 +14,34 @@ using SpecialFunctions
1214
@test testf((x,y)->x.^y, rand(Float32, 1), -rand(range, 1))
1315
end
1416
end
17+
18+
@testset "min/max" begin
19+
for T in (Float32, Float64)
20+
@test testf((x,y)->max.(x, y), rand(Float32, 1), rand(T, 1))
21+
@test testf((x,y)->min.(x, y), rand(Float32, 1), rand(T, 1))
22+
end
23+
end
1524

1625
@testset "isinf" begin
17-
for x in (Inf32, Inf, NaN32, NaN)
26+
for x in (Inf32, Inf, NaN16, NaN32, NaN)
1827
@test testf(x->isinf.(x), [x])
1928
end
2029
end
2130

2231
@testset "isnan" begin
23-
for x in (Inf32, Inf, NaN32, NaN)
32+
for x in (Inf32, Inf, NaN16, NaN32, NaN)
2433
@test testf(x->isnan.(x), [x])
2534
end
2635
end
2736

2837
for op in (exp, angle, exp2, exp10,)
2938
@testset "$op" begin
30-
for T in (Float16, Float32, Float64)
39+
for T in (Float32, Float64)
3140
@test testf(x->op.(x), rand(T, 1))
3241
@test testf(x->op.(x), -rand(T, 1))
3342
end
3443
end
3544
end
36-
3745
for op in (expm1,)
3846
@testset "$op" begin
3947
# FIXME: add expm1(::Float16) to Base
@@ -50,7 +58,6 @@ using SpecialFunctions
5058
@test testf(x->op.(x), rand(T, 1))
5159
@test testf(x->op.(x), -rand(T, 1))
5260
end
53-
5461
end
5562
end
5663
@testset "mod and rem" begin
@@ -97,6 +104,18 @@ using SpecialFunctions
97104
# JuliaGPU/CUDA.jl#1085: exp uses Base.sincos performing a global CPU load
98105
@test testf(x->exp.(x), [1e7im])
99106
end
107+
108+
@testset "Real - $op" for op in (exp, abs, abs2, exp10, log10)
109+
@testset "$T" for T in (Float16, Float32, Float64)
110+
@test testf(x->op.(x), rand(T, 1))
111+
end
112+
end
113+
114+
@testset "Float16 - $op" for op in (log,exp,exp2,exp10,log2,log10)
115+
@testset "$T" for T in (Float16, )
116+
@test testf(x->op.(x), rand(T, 1))
117+
end
118+
end
100119

101120
@testset "fastmath" begin
102121
# libdevice provides some fast math functions

0 commit comments

Comments
 (0)