Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improvements to float intrinsics #531

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open

Improvements to float intrinsics #531

wants to merge 8 commits into from

Conversation

christiangnrd
Copy link
Contributor

@christiangnrd christiangnrd commented Jan 30, 2025

Replaces #529

Description of each commit in order:

  • tanpi has been in base Julia since 1.10, so switch from @device_function to @device_override.
  • Test all currently-defined float intrinsics and slightly refactor the old ones.
  • List the float intrinsics from the metal shading language in the tests.
  • Add intrinsics for clamp & sign
  • Add intrinsics for nextafter
  • Add intrinsics for 2-arg atan
  • Add intrinsics for 3-arg max & 'min'

Copy link
Contributor

Your PR requires formatting changes to meet the project's style guidelines.
Please consider running Runic (git runic main) to apply these changes.

Click here to view the suggested changes.
diff --git a/test/device/intrinsics.jl b/test/device/intrinsics.jl
index 2f1d3b78..95eb9a6a 100644
--- a/test/device/intrinsics.jl
+++ b/test/device/intrinsics.jl
@@ -173,210 +173,210 @@ MATH_INTR_FUNCS_3_ARG = [
 ]
 
 @testset "math" begin
-# 1-arg functions
-@testset "$(fun)()::$T" for fun in MATH_INTR_FUNCS_1_ARG, T in (Float32, Float16)
-    cpuarr = if fun in [log, log2, log10, Metal.rsqrt, sqrt]
-        rand(T, 4)
-    else
-        T[0.0, -0.0, rand(T), -rand(T)]
-    end
+    # 1-arg functions
+    @testset "$(fun)()::$T" for fun in MATH_INTR_FUNCS_1_ARG, T in (Float32, Float16)
+        cpuarr = if fun in [log, log2, log10, Metal.rsqrt, sqrt]
+            rand(T, 4)
+        else
+            T[0.0, -0.0, rand(T), -rand(T)]
+        end
 
-    mtlarr = MtlArray(cpuarr)
+        mtlarr = MtlArray(cpuarr)
 
-    mtlout = fill!(similar(mtlarr), 0)
+        mtlout = fill!(similar(mtlarr), 0)
 
-    function kernel(res, arr)
+        function kernel(res, arr)
         idx = thread_position_in_grid_1d()
-        res[idx] = fun(arr[idx])
+            res[idx] = fun(arr[idx])
         return nothing
     end
-    Metal.@sync @metal threads = length(mtlout) kernel(mtlout, mtlarr)
-    @eval @test Array($mtlout) ≈ $fun.($cpuarr)
-end
-# 2-arg functions
-@testset "$(fun)()::$T" for T in (Float32, Float16), fun in MATH_INTR_FUNCS_2_ARG
-    N = 4
-    arr1 = randn(T, N)
-    arr2 = randn(T, N)
-    mtlarr1 = MtlArray(arr1)
-    mtlarr2 = MtlArray(arr2)
+        Metal.@sync @metal threads = length(mtlout) kernel(mtlout, mtlarr)
+        @eval @test Array($mtlout) ≈ $fun.($cpuarr)
+    end
+    # 2-arg functions
+    @testset "$(fun)()::$T" for T in (Float32, Float16), fun in MATH_INTR_FUNCS_2_ARG
+        N = 4
+        arr1 = randn(T, N)
+        arr2 = randn(T, N)
+        mtlarr1 = MtlArray(arr1)
+        mtlarr2 = MtlArray(arr2)
 
-    mtlout = fill!(similar(mtlarr1), 0)
+        mtlout = fill!(similar(mtlarr1), 0)
 
-    function kernel(res, x, y)
+        function kernel(res, x, y)
         idx = thread_position_in_grid_1d()
-        res[idx] = fun(x[idx], y[idx])
+            res[idx] = fun(x[idx], y[idx])
         return nothing
     end
-    Metal.@sync @metal threads = N kernel(mtlout, mtlarr1, mtlarr2)
-    @eval @test Array($mtlout) ≈ $fun.($arr1, $arr2)
-end
-# 3-arg functions
-@testset "$(fun)()::$T" for T in (Float32, Float16), fun in MATH_INTR_FUNCS_3_ARG
-    N = 4
-    arr1 = randn(T, N)
-    arr2 = randn(T, N)
-    arr3 = randn(T, N)
+        Metal.@sync @metal threads = N kernel(mtlout, mtlarr1, mtlarr2)
+        @eval @test Array($mtlout) ≈ $fun.($arr1, $arr2)
+    end
+    # 3-arg functions
+    @testset "$(fun)()::$T" for T in (Float32, Float16), fun in MATH_INTR_FUNCS_3_ARG
+        N = 4
+        arr1 = randn(T, N)
+        arr2 = randn(T, N)
+        arr3 = randn(T, N)
 
-    mtlarr1 = MtlArray(arr1)
-    mtlarr2 = MtlArray(arr2)
-    mtlarr3 = MtlArray(arr3)
+        mtlarr1 = MtlArray(arr1)
+        mtlarr2 = MtlArray(arr2)
+        mtlarr3 = MtlArray(arr3)
 
-    mtlout = fill!(similar(mtlarr1), 0)
+        mtlout = fill!(similar(mtlarr1), 0)
 
-    function kernel(res, x, y, z)
+        function kernel(res, x, y, z)
         idx = thread_position_in_grid_1d()
-        res[idx] = fun(x[idx], y[idx], z[idx])
+            res[idx] = fun(x[idx], y[idx], z[idx])
         return nothing
     end
-    Metal.@sync @metal threads = N kernel(mtlout, mtlarr1, mtlarr2, mtlarr3)
-    @eval @test Array($mtlout) ≈ $fun.($arr1, $arr2, $arr3)
-end
+        Metal.@sync @metal threads = N kernel(mtlout, mtlarr1, mtlarr2, mtlarr3)
+        @eval @test Array($mtlout) ≈ $fun.($arr1, $arr2, $arr3)
+    end
 end
 
 @testset "unique math" begin
-@testset "$T" for T in (Float32, Float16)
-    let # acosh
-        arr = T[0, rand(T, 3)...] .+ T(1)
-        buffer = MtlArray(arr)
-        vec = acosh.(buffer)
-        @test Array(vec) ≈ acosh.(arr)
-    end
-
-    let # sincos
-        N = 4
-        arr = rand(T, N)
-        bufferA = MtlArray(arr)
-        bufferB = MtlArray(arr)
-        function intr_test3(arr_sin, arr_cos)
-            idx = thread_position_in_grid_1d()
-            sinres, cosres = sincos(arr_cos[idx])
-            arr_sin[idx] = sinres
-            arr_cos[idx] = cosres
-            return nothing
+    @testset "$T" for T in (Float32, Float16)
+        let # acosh
+            arr = T[0, rand(T, 3)...] .+ T(1)
+            buffer = MtlArray(arr)
+            vec = acosh.(buffer)
+            @test Array(vec) ≈ acosh.(arr)
         end
-        # Broken with Float16
-        if T == Float16
-            @test_broken Metal.@sync @metal threads = N intr_test3(bufferA, bufferB)
-        else
-            Metal.@sync @metal threads = N intr_test3(bufferA, bufferB)
-            @test Array(bufferA) ≈ sin.(arr)
-            @test Array(bufferB) ≈ cos.(arr)
+
+        let # sincos
+            N = 4
+            arr = rand(T, N)
+            bufferA = MtlArray(arr)
+            bufferB = MtlArray(arr)
+            function intr_test3(arr_sin, arr_cos)
+                idx = thread_position_in_grid_1d()
+                sinres, cosres = sincos(arr_cos[idx])
+                arr_sin[idx] = sinres
+                arr_cos[idx] = cosres
+                return nothing
+            end
+            # Broken with Float16
+            if T == Float16
+                @test_broken Metal.@sync @metal threads = N intr_test3(bufferA, bufferB)
+            else
+                Metal.@sync @metal threads = N intr_test3(bufferA, bufferB)
+                @test Array(bufferA) ≈ sin.(arr)
+                @test Array(bufferB) ≈ cos.(arr)
+            end
         end
-    end
 
-    let # clamp
-        N = 4
-        in = randn(T, N)
-        minval = fill(T(-0.6), N)
-        maxval = fill(T(0.6), N)
+        let # clamp
+            N = 4
+            in = randn(T, N)
+            minval = fill(T(-0.6), N)
+            maxval = fill(T(0.6), N)
 
-        mtlin = MtlArray(in)
-        mtlminval = MtlArray(minval)
-        mtlmaxval = MtlArray(maxval)
+            mtlin = MtlArray(in)
+            mtlminval = MtlArray(minval)
+            mtlmaxval = MtlArray(maxval)
 
-        mtlout = fill!(similar(mtlin), 0)
+            mtlout = fill!(similar(mtlin), 0)
 
-        function kernel(res, x, y, z)
-            idx = thread_position_in_grid_1d()
-            res[idx] = clamp(x[idx], y[idx], z[idx])
-            return nothing
+            function kernel(res, x, y, z)
+                idx = thread_position_in_grid_1d()
+                res[idx] = clamp(x[idx], y[idx], z[idx])
+                return nothing
+            end
+            Metal.@sync @metal threads = N kernel(mtlout, mtlin, mtlminval, mtlmaxval)
+            @test Array(mtlout) == clamp.(in, minval, maxval)
         end
-        Metal.@sync @metal threads = N kernel(mtlout, mtlin, mtlminval, mtlmaxval)
-        @test Array(mtlout) == clamp.(in, minval, maxval)
-    end
 
-    let #pow
-        N = 4
-        arr1 = rand(T, N)
-        arr2 = rand(T, N)
-        mtlarr1 = MtlArray(arr1)
-        mtlarr2 = MtlArray(arr2)
+        let #pow
+            N = 4
+            arr1 = rand(T, N)
+            arr2 = rand(T, N)
+            mtlarr1 = MtlArray(arr1)
+            mtlarr2 = MtlArray(arr2)
 
-        mtlout = fill!(similar(mtlarr1), 0)
+            mtlout = fill!(similar(mtlarr1), 0)
 
-        function kernel(res, x, y)
-            idx = thread_position_in_grid_1d()
-            res[idx] = x[idx]^y[idx]
-            return nothing
+            function kernel(res, x, y)
+                idx = thread_position_in_grid_1d()
+                res[idx] = x[idx]^y[idx]
+                return nothing
+            end
+            Metal.@sync @metal threads = N kernel(mtlout, mtlarr1, mtlarr2)
+            @test Array(mtlout) ≈ arr1 .^ arr2
         end
-        Metal.@sync @metal threads = N kernel(mtlout, mtlarr1, mtlarr2)
-        @test Array(mtlout) ≈ arr1 .^ arr2
-    end
 
-    let #powr
-        N = 4
-        arr1 = rand(T, N)
-        arr2 = rand(T, N)
-        mtlarr1 = MtlArray(arr1)
-        mtlarr2 = MtlArray(arr2)
+        let #powr
+            N = 4
+            arr1 = rand(T, N)
+            arr2 = rand(T, N)
+            mtlarr1 = MtlArray(arr1)
+            mtlarr2 = MtlArray(arr2)
 
-        mtlout = fill!(similar(mtlarr1), 0)
+            mtlout = fill!(similar(mtlarr1), 0)
 
-        function kernel(res, x, y)
-            idx = thread_position_in_grid_1d()
-            res[idx] = Metal.powr(x[idx], y[idx])
-            return nothing
+            function kernel(res, x, y)
+                idx = thread_position_in_grid_1d()
+                res[idx] = Metal.powr(x[idx], y[idx])
+                return nothing
+            end
+            Metal.@sync @metal threads = N kernel(mtlout, mtlarr1, mtlarr2)
+            @test Array(mtlout) ≈ arr1 .^ arr2
         end
-        Metal.@sync @metal threads = N kernel(mtlout, mtlarr1, mtlarr2)
-        @test Array(mtlout) ≈ arr1 .^ arr2
-    end
 
-    let # log1p
-        arr = collect(LinRange(nextfloat(-1.0f0), 10.0f0, 20))
-        buffer = MtlArray(arr)
-        vec = Array(log1p.(buffer))
-        @test vec ≈ log1p.(arr)
-    end
+        let # log1p
+            arr = collect(LinRange(nextfloat(-1.0f0), 10.0f0, 20))
+            buffer = MtlArray(arr)
+            vec = Array(log1p.(buffer))
+            @test vec ≈ log1p.(arr)
+        end
 
-    let # erf
-        arr = collect(LinRange(nextfloat(-3.0f0), 3.0f0, 20))
-        buffer = MtlArray(arr)
-        vec = Array(SpecialFunctions.erf.(buffer))
-        @test vec ≈ SpecialFunctions.erf.(arr)
-    end
+        let # erf
+            arr = collect(LinRange(nextfloat(-3.0f0), 3.0f0, 20))
+            buffer = MtlArray(arr)
+            vec = Array(SpecialFunctions.erf.(buffer))
+            @test vec ≈ SpecialFunctions.erf.(arr)
+        end
 
-    let # erfc
-        arr = collect(LinRange(nextfloat(-3.0f0), 3.0f0, 20))
-        buffer = MtlArray(arr)
-        vec = Array(SpecialFunctions.erfc.(buffer))
-        @test vec ≈ SpecialFunctions.erfc.(arr)
-    end
+        let # erfc
+            arr = collect(LinRange(nextfloat(-3.0f0), 3.0f0, 20))
+            buffer = MtlArray(arr)
+            vec = Array(SpecialFunctions.erfc.(buffer))
+            @test vec ≈ SpecialFunctions.erfc.(arr)
+        end
 
-    let # erfinv
-        arr = collect(LinRange(-1.0f0, 1.0f0, 20))
-        buffer = MtlArray(arr)
-        vec = Array(SpecialFunctions.erfinv.(buffer))
-        @test vec ≈ SpecialFunctions.erfinv.(arr)
-    end
+        let # erfinv
+            arr = collect(LinRange(-1.0f0, 1.0f0, 20))
+            buffer = MtlArray(arr)
+            vec = Array(SpecialFunctions.erfinv.(buffer))
+            @test vec ≈ SpecialFunctions.erfinv.(arr)
+        end
 
-    let # expm1
-        arr = collect(LinRange(nextfloat(-88.0f0), 88.0f0, 100))
-        buffer = MtlArray(arr)
-        vec = Array(expm1.(buffer))
-        @test vec ≈ expm1.(arr)
-    end
+        let # expm1
+            arr = collect(LinRange(nextfloat(-88.0f0), 88.0f0, 100))
+            buffer = MtlArray(arr)
+            vec = Array(expm1.(buffer))
+            @test vec ≈ expm1.(arr)
+        end
 
 
-    let # nextafter
-        if Metal.is_macos(v"14")
-            N = 4
-            function nextafter_test(X, y)
-                idx = thread_position_in_grid_1d()
-                X[idx] = Metal.nextafter(X[idx], y)
-                return nothing
-            end
-            arr = rand(T, N)
-            buffer = MtlArray(arr)
-            Metal.@sync @metal threads = N nextafter_test(buffer, typemax(T))
-            @test Array(buffer) == nextfloat.(arr)
+        let # nextafter
+            if Metal.is_macos(v"14")
+                N = 4
+                function nextafter_test(X, y)
+                    idx = thread_position_in_grid_1d()
+                    X[idx] = Metal.nextafter(X[idx], y)
+                    return nothing
+                end
+                arr = rand(T, N)
+                buffer = MtlArray(arr)
+                Metal.@sync @metal threads = N nextafter_test(buffer, typemax(T))
+                @test Array(buffer) == nextfloat.(arr)
 
-            Metal.@sync @metal threads = N nextafter_test(buffer, typemin(T))
-            @test Array(buffer) == arr
+                Metal.@sync @metal threads = N nextafter_test(buffer, typemin(T))
+                @test Array(buffer) == arr
+            end
         end
     end
 end
-end
 
 ############################################################################################
 

Copy link
Contributor

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Metal Benchmarks

Benchmark suite Current: 8c4f365 Previous: 1b811cb Ratio
private array/construct 24854.166666666668 ns 26482.666666666668 ns 0.94
private array/broadcast 465500 ns 457584 ns 1.02
private array/random/randn/Float32 810625 ns 857895.5 ns 0.94
private array/random/randn!/Float32 631125 ns 657709 ns 0.96
private array/random/rand!/Int64 569458 ns 564458 ns 1.01
private array/random/rand!/Float32 592520.5 ns 610416 ns 0.97
private array/random/rand/Int64 743146 ns 805354 ns 0.92
private array/random/rand/Float32 601688 ns 602791.5 ns 1.00
private array/copyto!/gpu_to_gpu 669520.5 ns 672437.5 ns 1.00
private array/copyto!/cpu_to_gpu 778208 ns 635271 ns 1.23
private array/copyto!/gpu_to_cpu 679312.5 ns 815104.5 ns 0.83
private array/accumulate/1d 1341709 ns 1324625 ns 1.01
private array/accumulate/2d 1389834 ns 1393666.5 ns 1.00
private array/iteration/findall/int 2070250 ns 2070667 ns 1.00
private array/iteration/findall/bool 1808853.5 ns 1827208 ns 0.99
private array/iteration/findfirst/int 1736771.5 ns 1688333.5 ns 1.03
private array/iteration/findfirst/bool 1679187.5 ns 1671750 ns 1.00
private array/iteration/scalar 3441584 ns 3643750 ns 0.94
private array/iteration/logical 3218750 ns 3166208 ns 1.02
private array/iteration/findmin/1d 1764375 ns 1758542 ns 1.00
private array/iteration/findmin/2d 1353250 ns 1357792 ns 1.00
private array/reductions/reduce/1d 1050666.5 ns 1043270.5 ns 1.01
private array/reductions/reduce/2d 666750 ns 663333 ns 1.01
private array/reductions/mapreduce/1d 1038167 ns 1043084 ns 1.00
private array/reductions/mapreduce/2d 670416 ns 665896 ns 1.01
private array/permutedims/4d 2545750 ns 2529437.5 ns 1.01
private array/permutedims/2d 1014916 ns 1024875 ns 0.99
private array/permutedims/3d 1594083 ns 1585208 ns 1.01
private array/copy 554709 ns 592958 ns 0.94
latency/precompile 8822746791 ns 8799224667 ns 1.00
latency/ttfp 3615673875 ns 3600655125 ns 1.00
latency/import 1236130792 ns 1231127083 ns 1.00
integration/metaldevrt 732667 ns 701416 ns 1.04
integration/byval/slices=1 1570917 ns 1566354 ns 1.00
integration/byval/slices=3 10470667 ns 10376042 ns 1.01
integration/byval/reference 1558791 ns 1610875 ns 0.97
integration/byval/slices=2 2581083 ns 2715041 ns 0.95
kernel/indexing 455500 ns 474291.5 ns 0.96
kernel/indexing_checked 456083 ns 475895.5 ns 0.96
kernel/launch 8375 ns 8208 ns 1.02
metal/synchronization/stream 14750 ns 15041 ns 0.98
metal/synchronization/context 15333 ns 15000 ns 1.02
shared array/construct 24175 ns 24145.833333333332 ns 1.00
shared array/broadcast 458250 ns 461145.5 ns 0.99
shared array/random/randn/Float32 839708 ns 813750 ns 1.03
shared array/random/randn!/Float32 631791 ns 636542 ns 0.99
shared array/random/rand!/Int64 557916 ns 568417 ns 0.98
shared array/random/rand!/Float32 594250 ns 603334 ns 0.98
shared array/random/rand/Int64 754250 ns 778417 ns 0.97
shared array/random/rand/Float32 589167 ns 616250 ns 0.96
shared array/copyto!/gpu_to_gpu 83645.5 ns 84667 ns 0.99
shared array/copyto!/cpu_to_gpu 83542 ns 83375 ns 1.00
shared array/copyto!/gpu_to_cpu 84041 ns 84625 ns 0.99
shared array/accumulate/1d 1351354 ns 1346167 ns 1.00
shared array/accumulate/2d 1392291 ns 1397291.5 ns 1.00
shared array/iteration/findall/int 1840375 ns 1800250 ns 1.02
shared array/iteration/findall/bool 1580500 ns 1590437.5 ns 0.99
shared array/iteration/findfirst/int 1401916.5 ns 1408125 ns 1.00
shared array/iteration/findfirst/bool 1377333 ns 1364584 ns 1.01
shared array/iteration/scalar 159458 ns 153583 ns 1.04
shared array/iteration/logical 2998999.5 ns 2981833.5 ns 1.01
shared array/iteration/findmin/1d 1474396 ns 1467625 ns 1.00
shared array/iteration/findmin/2d 1374541.5 ns 1369416 ns 1.00
shared array/reductions/reduce/1d 741500 ns 738166.5 ns 1.00
shared array/reductions/reduce/2d 669042 ns 663208.5 ns 1.01
shared array/reductions/mapreduce/1d 744312.5 ns 746583 ns 1.00
shared array/reductions/mapreduce/2d 673229.5 ns 670917 ns 1.00
shared array/permutedims/4d 2558541 ns 2524958.5 ns 1.01
shared array/permutedims/2d 1024062.5 ns 1028125 ns 1.00
shared array/permutedims/3d 1589917 ns 1588833 ns 1.00
shared array/copy 247958 ns 239042 ns 1.04

This comment was automatically generated by workflow using github-action-benchmark.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant