Skip to content

Commit e3d369e

Browse files
Fix Float16 sincos intrinsic (#533)
1 parent 924a130 commit e3d369e

File tree

2 files changed

+7
-13
lines changed

2 files changed

+7
-13
lines changed

src/device/intrinsics/math.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -252,18 +252,17 @@ end
252252

253253
@device_override function FastMath.sincos_fast(x::Float32)
254254
c = Ref{Cfloat}()
255-
s = ccall("extern air.fast_sincos.f32", llvmcall, Cfloat, (Cfloat, Ptr{Cfloat}), x, c)
255+
s = @typed_ccall("air.fast_sincos.f32", llvmcall, Cfloat, (Cfloat, Ptr{Cfloat}), x, c)
256256
(s, c[])
257257
end
258258
@device_override function Base.sincos(x::Float32)
259259
c = Ref{Cfloat}()
260-
s = ccall("extern air.sincos.f32", llvmcall, Cfloat, (Cfloat, Ptr{Cfloat}), x, c)
260+
s = @typed_ccall("air.sincos.f32", llvmcall, Cfloat, (Cfloat, Ptr{Cfloat}), x, c)
261261
(s, c[])
262262
end
263-
# XXX: Broken
264263
@device_override function Base.sincos(x::Float16)
265264
c = Ref{Float16}()
266-
s = ccall("extern air.sincos.f16", llvmcall, Float16, (Float16, Ptr{Float16}), x, c)
265+
s = @typed_ccall("air.sincos.f16", llvmcall, Float16, (Float16, Ptr{Float16}), x, c)
267266
(s, c[])
268267
end
269268

test/device/intrinsics.jl

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,6 @@ MATH_INTR_FUNCS_2_ARG = [
160160
# ldexp, # T ldexp(T x, Ti k)
161161
# modf, # T modf(T x, T &intval)
162162
# nextafter, # T nextafter(T x, T y) # Metal 3.1+
163-
# sincos,
164163
hypot, # NOT MSL but tested the same
165164
]
166165

@@ -257,14 +256,10 @@ end
257256
arr_cos[idx] = cosres
258257
return nothing
259258
end
260-
# Broken with Float16
261-
if T == Float16
262-
@test_broken Metal.@sync @metal threads = N intr_test3(bufferA, bufferB)
263-
else
264-
Metal.@sync @metal threads = N intr_test3(bufferA, bufferB)
265-
@test Array(bufferA) sin.(arr)
266-
@test Array(bufferB) cos.(arr)
267-
end
259+
260+
Metal.@sync @metal threads = N intr_test3(bufferA, bufferB)
261+
@test Array(bufferA) sin.(arr)
262+
@test Array(bufferB) cos.(arr)
268263
end
269264

270265
let # clamp

0 commit comments

Comments
 (0)