diff --git a/src/device/intrinsics/math.jl b/src/device/intrinsics/math.jl index a113f84b9..8c7973639 100644 --- a/src/device/intrinsics/math.jl +++ b/src/device/intrinsics/math.jl @@ -2,12 +2,19 @@ using Base: FastMath using Base.Math: throw_complex_domainerror +import Core: Float16, Float32 # TODO: # - wrap all intrinsics from include/metal/metal_math # - add support for vector types # - consider emitting LLVM intrinsics and lowering those in the back-end +### Constants +# π and ℯ +for T in (:Float16,:Float32), R in (RoundUp, RoundDown), irr in (π, ℯ) + @eval @device_override $T(::typeof($irr), ::typeof($R)) = $@eval($T($irr,$R)) +end + ### Common Intrinsics @device_function clamp_fast(x::Float32, minval::Float32, maxval::Float32) = ccall("extern air.fast_clamp.f32", llvmcall, Cfloat, (Cfloat, Cfloat, Cfloat), x, minval, maxval) @device_override Base.clamp(x::Float32, minval::Float32, maxval::Float32) = ccall("extern air.clamp.f32", llvmcall, Cfloat, (Cfloat, Cfloat, Cfloat), x, minval, maxval) diff --git a/test/device/intrinsics/math.jl b/test/device/intrinsics/math.jl index 185f76eac..ac1d7e9e4 100644 --- a/test/device/intrinsics/math.jl +++ b/test/device/intrinsics/math.jl @@ -311,6 +311,21 @@ end ir = sprint(io->(@device_code_llvm io=io dump_module=true @metal metal = v"3.0" nextafter_out_test())) @test occursin(Regex("@air\\.sign\\.f$(8*sizeof(T))"), ir) end + + # Borrowed from the Julia "Irrationals compared with Rationals and Floats" testset + @testset "Comparisons with $irr" for irr in (π, ℯ) + @eval function convert_test(res) + res[1] = $T($irr, RoundDown) < $irr + res[2] = $T($irr, RoundUp) > $irr + res[3] = !($T($irr, RoundDown) > $irr) + res[4] = !($T($irr, RoundUp) < $irr) + return nothing + end + + res = MtlArray(zeros(Bool, 4)) + Metal.@sync @metal convert_test(res) + @test all(Array(res)) + end end end