Skip to content

Commit b8ab3b6

Browse files
Float intrinsics fixes & test improvements (#531)
* Use `Base.tanpi` in intrinsics `tanpi` is in Julia since 1.10 so allsupported versions have it * Test more intrinsics and fix `min`/`max` Also clean up the different tests * [NFC] Add commented out intrinsics to test once added * `clamp` & `sign` * 2-arg atan * 3-arg max and min
1 parent ca092c8 commit b8ab3b6

File tree

2 files changed

+265
-53
lines changed

2 files changed

+265
-53
lines changed

src/device/intrinsics/math.jl

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,14 @@ using Base.Math: throw_complex_domainerror
88
# - add support for vector types
99
# - consider emitting LLVM intrinsics and lowering those in the back-end
1010

11+
### Common Intrinsics
12+
@device_function clamp_fast(x::Float32, minval::Float32, maxval::Float32) = ccall("extern air.fast_clamp.f32", llvmcall, Cfloat, (Cfloat, Cfloat, Cfloat), x, minval, maxval)
13+
@device_override Base.clamp(x::Float32, minval::Float32, maxval::Float32) = ccall("extern air.clamp.f32", llvmcall, Cfloat, (Cfloat, Cfloat, Cfloat), x, minval, maxval)
14+
@device_override Base.clamp(x::Float16, minval::Float16, maxval::Float16) = ccall("extern air.clamp.f16", llvmcall, Float16, (Float16, Float16, Float16), x, minval, maxval)
15+
16+
@device_override Base.sign(x::Float32) = ccall("extern air.sign.f32", llvmcall, Cfloat, (Cfloat,), x)
17+
@device_override Base.sign(x::Float16) = ccall("extern air.sign.f16", llvmcall, Float16, (Float16,), x)
18+
1119
### Floating Point Intrinsics
1220

1321
## Metal only supports single and half-precision floating-point types (and their vector counterparts)
@@ -17,13 +25,21 @@ using Base.Math: throw_complex_domainerror
1725
@device_override Base.abs(x::Float32) = ccall("extern air.fabs.f32", llvmcall, Cfloat, (Cfloat,), x)
1826
@device_override Base.abs(x::Float16) = ccall("extern air.fabs.f16", llvmcall, Float16, (Float16,), x)
1927

20-
@device_override FastMath.min_fast(x::Float32) = ccall("extern air.fast_fmin.f32", llvmcall, Cfloat, (Cfloat,), x)
21-
@device_override Base.min(x::Float32) = ccall("extern air.fmin.f32", llvmcall, Cfloat, (Cfloat,), x)
22-
@device_override Base.min(x::Float16) = ccall("extern air.fmin.f16", llvmcall, Float16, (Float16,), x)
28+
@device_override FastMath.min_fast(x::Float32, y::Float32) = ccall("extern air.fast_fmin.f32", llvmcall, Cfloat, (Cfloat, Cfloat), x, y)
29+
@device_override Base.min(x::Float32, y::Float32) = ccall("extern air.fmin.f32", llvmcall, Cfloat, (Cfloat, Cfloat), x, y)
30+
@device_override Base.min(x::Float16, y::Float16) = ccall("extern air.fmin.f16", llvmcall, Float16, (Float16, Float16), x, y)
31+
32+
@device_override FastMath.min_fast(x::Float32, y::Float32, z::Float32) = ccall("extern air.fast_fmin3.f32", llvmcall, Cfloat, (Cfloat, Cfloat, Cfloat), x, y, z)
33+
@device_override Base.min(x::Float32, y::Float32, z::Float32) = ccall("extern air.fmin3.f32", llvmcall, Cfloat, (Cfloat, Cfloat, Cfloat), x, y, z)
34+
@device_override Base.min(x::Float16, y::Float16, z::Float16) = ccall("extern air.fmin3.f16", llvmcall, Float16, (Float16, Float16, Float16), x, y, z)
2335

24-
@device_override FastMath.max_fast(x::Float32) = ccall("extern air.fast_fmax.f32", llvmcall, Cfloat, (Cfloat,), x)
25-
@device_override Base.max(x::Float32) = ccall("extern air.fmax.f32", llvmcall, Cfloat, (Cfloat,), x)
26-
@device_override Base.max(x::Float16) = ccall("extern air.fmax.f16", llvmcall, Float16, (Float16,), x)
36+
@device_override FastMath.max_fast(x::Float32, y::Float32) = ccall("extern air.fast_fmax.f32", llvmcall, Cfloat, (Cfloat, Cfloat), x, y)
37+
@device_override Base.max(x::Float32, y::Float32) = ccall("extern air.fmax.f32", llvmcall, Cfloat, (Cfloat, Cfloat), x, y)
38+
@device_override Base.max(x::Float16, y::Float16) = ccall("extern air.fmax.f16", llvmcall, Float16, (Float16, Float16), x, y)
39+
40+
@device_override FastMath.max_fast(x::Float32, y::Float32, z::Float32) = ccall("extern air.fast_fmax3.f32", llvmcall, Cfloat, (Cfloat, Cfloat, Cfloat), x, y, z)
41+
@device_override Base.max(x::Float32, y::Float32, z::Float32) = ccall("extern air.fmax3.f32", llvmcall, Cfloat, (Cfloat, Cfloat, Cfloat), x, y, z)
42+
@device_override Base.max(x::Float16, y::Float16, z::Float16) = ccall("extern air.fmax3.f16", llvmcall, Float16, (Float16, Float16, Float16), x, y, z)
2743

2844
@device_override FastMath.acos_fast(x::Float32) = ccall("extern air.fast_acos.f32", llvmcall, Cfloat, (Cfloat,), x)
2945
@device_override Base.acos(x::Float32) = ccall("extern air.acos.f32", llvmcall, Cfloat, (Cfloat,), x)
@@ -45,6 +61,10 @@ using Base.Math: throw_complex_domainerror
4561
@device_override Base.atan(x::Float32) = ccall("extern air.atan.f32", llvmcall, Cfloat, (Cfloat,), x)
4662
@device_override Base.atan(x::Float16) = ccall("extern air.atan.f16", llvmcall, Float16, (Float16,), x)
4763

64+
@device_override FastMath.atan_fast(x::Float32, y::Float32) = ccall("extern air.fast_atan2.f32", llvmcall, Cfloat, (Cfloat, Cfloat), x, y)
65+
@device_override Base.atan(x::Float32, y::Float32) = ccall("extern air.atan2.f32", llvmcall, Cfloat, (Cfloat, Cfloat), x, y)
66+
@device_override Base.atan(x::Float16, y::Float16) = ccall("extern air.atan2.f16", llvmcall, Float16, (Float16, Float16), x, y)
67+
4868
@device_override FastMath.atanh_fast(x::Float32) = ccall("extern air.fast_atanh.f32", llvmcall, Cfloat, (Cfloat,), x)
4969
@device_override Base.atanh(x::Float32) = ccall("extern air.atanh.f32", llvmcall, Cfloat, (Cfloat,), x)
5070
@device_override Base.atanh(x::Float16) = ccall("extern air.atanh.f16", llvmcall, Float16, (Float16,), x)
@@ -240,6 +260,7 @@ end
240260
s = ccall("extern air.sincos.f32", llvmcall, Cfloat, (Cfloat, Ptr{Cfloat}), x, c)
241261
(s, c[])
242262
end
263+
# XXX: Broken
243264
@device_override function Base.sincos(x::Float16)
244265
c = Ref{Float16}()
245266
s = ccall("extern air.sincos.f16", llvmcall, Float16, (Float16, Ptr{Float16}), x, c)
@@ -267,8 +288,8 @@ end
267288
@device_override Base.tanh(x::Float16) = ccall("extern air.tanh.f16", llvmcall, Float16, (Float16,), x)
268289

269290
@device_function tanpi_fast(x::Float32) = ccall("extern air.fast_tanpi.f32", llvmcall, Cfloat, (Cfloat,), x)
270-
@device_function tanpi(x::Float32) = ccall("extern air.tanpi.f32", llvmcall, Cfloat, (Cfloat,), x)
271-
@device_function tanpi(x::Float16) = ccall("extern air.tanpi.f16", llvmcall, Float16, (Float16,), x)
291+
@device_override Base.tanpi(x::Float32) = ccall("extern air.tanpi.f32", llvmcall, Cfloat, (Cfloat,), x)
292+
@device_override Base.tanpi(x::Float16) = ccall("extern air.tanpi.f16", llvmcall, Float16, (Float16,), x)
272293

273294
@device_function trunc_fast(x::Float32) = ccall("extern air.fast_trunc.f32", llvmcall, Cfloat, (Cfloat,), x)
274295
@device_override Base.trunc(x::Float32) = ccall("extern air.trunc.f32", llvmcall, Cfloat, (Cfloat,), x)
@@ -418,7 +439,7 @@ end
418439
j = fma(1.442695f0, a, 12582912.0f0)
419440
j = j - 12582912.0f0
420441
i = unsafe_trunc(Int32, j)
421-
f = fma(j, -6.93145752f-1, a) # log_2_hi
442+
f = fma(j, -6.93145752f-1, a) # log_2_hi
422443
f = fma(j, -1.42860677f-6, f) # log_2_lo
423444

424445
# approximate r = exp(f)-1 on interval [-log(2)/2, +log(2)/2]

test/device/intrinsics.jl

Lines changed: 235 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
using SpecialFunctions
21
using Metal: metal_support
2+
using Random
3+
using SpecialFunctions
34

45
@testset "arguments" begin
56
@on_device dispatch_quadgroups_per_threadgroup()
@@ -103,71 +104,261 @@ end
103104

104105
############################################################################################
105106

107+
MATH_INTR_FUNCS_1_ARG = [
108+
# Common functions
109+
# saturate, # T saturate(T x) Clamp between 0.0 and 1.0
110+
sign, # T sign(T x) returns 0.0 if x is NaN
111+
112+
# float math
113+
acos, # T acos(T x)
114+
asin, # T asin(T x)
115+
asinh, # T asinh(T x)
116+
atan, # T atan(T x)
117+
atanh, # T atanh(T x)
118+
ceil, # T ceil(T x)
119+
cos, # T cos(T x)
120+
cosh, # T cosh(T x)
121+
cospi, # T cospi(T x)
122+
exp, # T exp(T x)
123+
exp2, # T exp2(T x)
124+
exp10, # T exp10(T x)
125+
abs, #T [f]abs(T x)
126+
floor, # T floor(T x)
127+
Metal.fract, # T fract(T x)
128+
# ilogb, # Ti ilogb(T x)
129+
log, # T log(T x)
130+
log2, # T log2(T x)
131+
log10, # T log10(T x)
132+
# Metal.rint, # T rint(T x) # TODO: Add test. Not sure what the behaviour actually is
133+
round, # T round(T x)
134+
Metal.rsqrt, # T rsqrt(T x)
135+
sin, # T sin(T x)
136+
sinh, # T sinh(T x)
137+
sinpi, # T sinpi(T x)
138+
sqrt, # sqrt(T x)
139+
tan, # T tan(T x)
140+
tanh, # T tanh(T x)
141+
tanpi, # T tanpi(T x)
142+
trunc, # T trunc(T x)
143+
]
144+
Metal.rsqrt(x::Float16) = 1 / sqrt(x)
145+
Metal.rsqrt(x::Float32) = 1 / sqrt(x)
146+
Metal.fract(x::Float16) = mod(x, 1)
147+
Metal.fract(x::Float32) = mod(x, 1)
148+
149+
MATH_INTR_FUNCS_2_ARG = [
150+
# Common function
151+
# step, # T step(T edge, T x) Returns 0.0 if x < edge, otherwise it returns 1.0
152+
153+
# float math
154+
atan, # T atan2(T x, T y) Compute arc tangent of y over x.
155+
# fdim, # T fdim(T x, T y)
156+
max, # T [f]max(T x, T y)
157+
min, # T [f]min(T x, T y)
158+
# fmod, # T fmod(T x, T y)
159+
# frexp, # T frexp(T x, Ti &exponent)
160+
# ldexp, # T ldexp(T x, Ti k)
161+
# modf, # T modf(T x, T &intval)
162+
# nextafter, # T nextafter(T x, T y) # Metal 3.1+
163+
# sincos,
164+
hypot, # NOT MSL but tested the same
165+
]
166+
167+
MATH_INTR_FUNCS_3_ARG = [
168+
# Common functions
169+
# mix, # T mix(T x, T y, T a) # x+(y-x)*a
170+
# smoothstep, # T smoothstep(T edge0, T edge1, T x)
171+
fma, # T fma(T a, T b, T c)
172+
max, # T max3(T x, T y, T z)
173+
# median3, # T median3(T x, T y, T z)
174+
min, # T min3(T x, T y, T z)
175+
]
176+
106177
@testset "math" begin
107-
a = ones(Float32,1)
108-
a .* Float32(3.14)
109-
bufferA = MtlArray{eltype(a),length(size(a)),Metal.SharedStorage}(a)
110-
vecA = unsafe_wrap(Vector{Float32}, pointer(bufferA), 1)
178+
# 1-arg functions
179+
@testset "$(fun)()::$T" for fun in MATH_INTR_FUNCS_1_ARG, T in (Float32, Float16)
180+
cpuarr = if fun in [log, log2, log10, Metal.rsqrt, sqrt]
181+
rand(T, 4)
182+
else
183+
T[0.0, -0.0, rand(T), -rand(T)]
184+
end
185+
186+
mtlarr = MtlArray(cpuarr)
187+
188+
mtlout = fill!(similar(mtlarr), 0)
111189

112-
function intr_test(arr)
190+
function kernel(res, arr)
113191
idx = thread_position_in_grid_1d()
114-
arr[idx] = cos(arr[idx])
192+
res[idx] = fun(arr[idx])
115193
return nothing
116194
end
117-
@metal intr_test(bufferA)
118-
synchronize()
119-
@test vecA cos.(a)
195+
Metal.@sync @metal threads = length(mtlout) kernel(mtlout, mtlarr)
196+
@eval @test Array($mtlout) $fun.($cpuarr)
197+
end
198+
# 2-arg functions
199+
@testset "$(fun)()::$T" for T in (Float32, Float16), fun in MATH_INTR_FUNCS_2_ARG
200+
N = 4
201+
arr1 = randn(T, N)
202+
arr2 = randn(T, N)
203+
mtlarr1 = MtlArray(arr1)
204+
mtlarr2 = MtlArray(arr2)
205+
206+
mtlout = fill!(similar(mtlarr1), 0)
120207

121-
function intr_test2(arr)
208+
function kernel(res, x, y)
122209
idx = thread_position_in_grid_1d()
123-
arr[idx] = Metal.rsqrt(arr[idx])
210+
res[idx] = fun(x[idx], y[idx])
124211
return nothing
125212
end
126-
@metal intr_test2(bufferA)
127-
synchronize()
213+
Metal.@sync @metal threads = N kernel(mtlout, mtlarr1, mtlarr2)
214+
@eval @test Array($mtlout) $fun.($arr1, $arr2)
215+
end
216+
# 3-arg functions
217+
@testset "$(fun)()::$T" for T in (Float32, Float16), fun in MATH_INTR_FUNCS_3_ARG
218+
N = 4
219+
arr1 = randn(T, N)
220+
arr2 = randn(T, N)
221+
arr3 = randn(T, N)
128222

129-
bufferB = MtlArray{eltype(a),length(size(a)),Metal.SharedStorage}(a)
130-
vecB = unsafe_wrap(Vector{Float32}, pointer(bufferB), 1)
223+
mtlarr1 = MtlArray(arr1)
224+
mtlarr2 = MtlArray(arr2)
225+
mtlarr3 = MtlArray(arr3)
131226

132-
function intr_test3(arr_sin, arr_cos)
227+
mtlout = fill!(similar(mtlarr1), 0)
228+
229+
function kernel(res, x, y, z)
133230
idx = thread_position_in_grid_1d()
134-
s, c = sincos(arr_cos[idx])
135-
arr_sin[idx] = s
136-
arr_cos[idx] = c
231+
res[idx] = fun(x[idx], y[idx], z[idx])
137232
return nothing
138233
end
234+
Metal.@sync @metal threads = N kernel(mtlout, mtlarr1, mtlarr2, mtlarr3)
235+
@eval @test Array($mtlout) $fun.($arr1, $arr2, $arr3)
236+
end
237+
end
139238

140-
@metal intr_test3(bufferA, bufferB)
141-
synchronize()
142-
@test vecA sin.(a)
143-
@test vecB cos.(a)
239+
@testset "unique math" begin
240+
@testset "$T" for T in (Float32, Float16)
241+
let # acosh
242+
arr = T[0, rand(T, 3)...] .+ T(1)
243+
buffer = MtlArray(arr)
244+
vec = acosh.(buffer)
245+
@test Array(vec) acosh.(arr)
246+
end
144247

145-
b = collect(LinRange(nextfloat(-1f0), 10f0, 20))
146-
bufferC = MtlArray(b)
147-
vecC = Array(log1p.(bufferC))
148-
@test vecC log1p.(b)
248+
let # sincos
249+
N = 4
250+
arr = rand(T, N)
251+
bufferA = MtlArray(arr)
252+
bufferB = MtlArray(arr)
253+
function intr_test3(arr_sin, arr_cos)
254+
idx = thread_position_in_grid_1d()
255+
sinres, cosres = sincos(arr_cos[idx])
256+
arr_sin[idx] = sinres
257+
arr_cos[idx] = cosres
258+
return nothing
259+
end
260+
# Broken with Float16
261+
if T == Float16
262+
@test_broken Metal.@sync @metal threads = N intr_test3(bufferA, bufferB)
263+
else
264+
Metal.@sync @metal threads = N intr_test3(bufferA, bufferB)
265+
@test Array(bufferA) sin.(arr)
266+
@test Array(bufferB) cos.(arr)
267+
end
268+
end
149269

270+
let # clamp
271+
N = 4
272+
in = randn(T, N)
273+
minval = fill(T(-0.6), N)
274+
maxval = fill(T(0.6), N)
150275

151-
d = collect(LinRange(nextfloat(-3.0f0), 3.0f0, 20))
152-
bufferD = MtlArray(d)
153-
vecD = Array(SpecialFunctions.erf.(bufferD))
154-
@test vecD SpecialFunctions.erf.(d)
276+
mtlin = MtlArray(in)
277+
mtlminval = MtlArray(minval)
278+
mtlmaxval = MtlArray(maxval)
155279

280+
mtlout = fill!(similar(mtlin), 0)
281+
282+
function kernel(res, x, y, z)
283+
idx = thread_position_in_grid_1d()
284+
res[idx] = clamp(x[idx], y[idx], z[idx])
285+
return nothing
286+
end
287+
Metal.@sync @metal threads = N kernel(mtlout, mtlin, mtlminval, mtlmaxval)
288+
@test Array(mtlout) == clamp.(in, minval, maxval)
289+
end
156290

157-
e = collect(LinRange(nextfloat(-3.0f0), 3.0f0, 20))
158-
bufferE = MtlArray(e)
159-
vecE = Array(SpecialFunctions.erfc.(bufferE))
160-
@test vecE SpecialFunctions.erfc.(e)
291+
let #pow
292+
N = 4
293+
arr1 = rand(T, N)
294+
arr2 = rand(T, N)
295+
mtlarr1 = MtlArray(arr1)
296+
mtlarr2 = MtlArray(arr2)
161297

162-
f = collect(LinRange(-1f0, 1f0, 20))
163-
bufferF = MtlArray(f)
164-
vecF = Array(SpecialFunctions.erfinv.(bufferF))
165-
@test vecF SpecialFunctions.erfinv.(f)
298+
mtlout = fill!(similar(mtlarr1), 0)
166299

167-
f = collect(LinRange(nextfloat(-88f0), 88f0, 100))
168-
bufferF = MtlArray(f)
169-
vecF = Array(expm1.(bufferF))
170-
@test vecF expm1.(f)
300+
function kernel(res, x, y)
301+
idx = thread_position_in_grid_1d()
302+
res[idx] = x[idx]^y[idx]
303+
return nothing
304+
end
305+
Metal.@sync @metal threads = N kernel(mtlout, mtlarr1, mtlarr2)
306+
@test Array(mtlout) arr1 .^ arr2
307+
end
308+
309+
let #powr
310+
N = 4
311+
arr1 = rand(T, N)
312+
arr2 = rand(T, N)
313+
mtlarr1 = MtlArray(arr1)
314+
mtlarr2 = MtlArray(arr2)
315+
316+
mtlout = fill!(similar(mtlarr1), 0)
317+
318+
function kernel(res, x, y)
319+
idx = thread_position_in_grid_1d()
320+
res[idx] = Metal.powr(x[idx], y[idx])
321+
return nothing
322+
end
323+
Metal.@sync @metal threads = N kernel(mtlout, mtlarr1, mtlarr2)
324+
@test Array(mtlout) arr1 .^ arr2
325+
end
326+
327+
let # log1p
328+
arr = collect(LinRange(nextfloat(-1.0f0), 10.0f0, 20))
329+
buffer = MtlArray(arr)
330+
vec = Array(log1p.(buffer))
331+
@test vec log1p.(arr)
332+
end
333+
334+
let # erf
335+
arr = collect(LinRange(nextfloat(-3.0f0), 3.0f0, 20))
336+
buffer = MtlArray(arr)
337+
vec = Array(SpecialFunctions.erf.(buffer))
338+
@test vec SpecialFunctions.erf.(arr)
339+
end
340+
341+
let # erfc
342+
arr = collect(LinRange(nextfloat(-3.0f0), 3.0f0, 20))
343+
buffer = MtlArray(arr)
344+
vec = Array(SpecialFunctions.erfc.(buffer))
345+
@test vec SpecialFunctions.erfc.(arr)
346+
end
347+
348+
let # erfinv
349+
arr = collect(LinRange(-1.0f0, 1.0f0, 20))
350+
buffer = MtlArray(arr)
351+
vec = Array(SpecialFunctions.erfinv.(buffer))
352+
@test vec SpecialFunctions.erfinv.(arr)
353+
end
354+
355+
let # expm1
356+
arr = collect(LinRange(nextfloat(-88.0f0), 88.0f0, 100))
357+
buffer = MtlArray(arr)
358+
vec = Array(expm1.(buffer))
359+
@test vec expm1.(arr)
360+
end
361+
end
171362
end
172363

173364
############################################################################################

0 commit comments

Comments
 (0)