Skip to content

Commit 2569fe1

Browse files
committed
Add nextafter intrinsic
1 parent 1b811cb commit 2569fe1

File tree

2 files changed

+27
-5
lines changed

2 files changed

+27
-5
lines changed

src/device/intrinsics/math.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,11 @@ end
274274
@device_override Base.trunc(x::Float32) = ccall("extern air.trunc.f32", llvmcall, Cfloat, (Cfloat,), x)
275275
@device_override Base.trunc(x::Float16) = ccall("extern air.trunc.f16", llvmcall, Float16, (Float16,), x)
276276

277+
@static if Metal.is_macos(v"14")
278+
@device_function nextafter(x::Float32, y::Float32) = ccall("extern air.nextafter.f32", llvmcall, Cfloat, (Cfloat, Cfloat), x, y)
279+
@device_function nextafter(x::Float16, y::Float16) = ccall("extern air.nextafter.f16", llvmcall, Float16, (Float16, Float16), x, y)
280+
end
281+
277282
# hypot without use of double
278283
#
279284
# taken from Cosmopolitan Libc
@@ -418,7 +423,7 @@ end
418423
j = fma(1.442695f0, a, 12582912.0f0)
419424
j = j - 12582912.0f0
420425
i = unsafe_trunc(Int32, j)
421-
f = fma(j, -6.93145752f-1, a) # log_2_hi
426+
f = fma(j, -6.93145752f-1, a) # log_2_hi
422427
f = fma(j, -1.42860677f-6, f) # log_2_lo
423428

424429
# approximate r = exp(f)-1 on interval [-log(2)/2, +log(2)/2]

test/device/intrinsics.jl

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -164,10 +164,27 @@ end
164164
vecF = Array(SpecialFunctions.erfinv.(bufferF))
165165
@test vecF ≈ SpecialFunctions.erfinv.(f)
166166

167-
f = collect(LinRange(nextfloat(-88f0), 88f0, 100))
168-
bufferF = MtlArray(f)
169-
vecF = Array(expm1.(bufferF))
170-
@test vecF ≈ expm1.(f)
167+
g = collect(LinRange(nextfloat(-88f0), 88f0, 100))
168+
bufferG = MtlArray(g)
169+
vecG = Array(expm1.(bufferG))
170+
@test vecG ≈ expm1.(g)
171+
172+
if Metal.is_macos(v"14")
173+
function nextafter_test(X, y)
174+
idx = thread_position_in_grid_1d()
175+
X[idx] = Metal.nextafter(X[idx], y)
176+
return nothing
177+
end
178+
h = rand(Float32,1)
179+
bufferH = MtlArray(h)
180+
@metal nextafter_test(bufferH,typemax(Float32))
181+
synchronize()
182+
@test Array(bufferH) ≈ nextfloat.(h)
183+
184+
@metal nextafter_test(bufferH,typemin(Float32))
185+
synchronize()
186+
@test Array(bufferH) ≈ h
187+
end
171188
end
172189

173190
############################################################################################

0 commit comments

Comments
 (0)