-
Notifications
You must be signed in to change notification settings - Fork 41
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Improvements to float intrinsics #531
Open
christiangnrd
wants to merge
8
commits into
main
Choose a base branch
from
intrinsics
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Conversation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
`tanpi` has been in Julia since 1.10, so all supported versions have it
Also clean up the different tests
Your PR requires formatting changes to meet the project's style guidelines. Click here to view the suggested changes.
diff --git a/test/device/intrinsics.jl b/test/device/intrinsics.jl
index 2f1d3b78..95eb9a6a 100644
--- a/test/device/intrinsics.jl
+++ b/test/device/intrinsics.jl
@@ -173,210 +173,210 @@ MATH_INTR_FUNCS_3_ARG = [
]
@testset "math" begin
-# 1-arg functions
-@testset "$(fun)()::$T" for fun in MATH_INTR_FUNCS_1_ARG, T in (Float32, Float16)
- cpuarr = if fun in [log, log2, log10, Metal.rsqrt, sqrt]
- rand(T, 4)
- else
- T[0.0, -0.0, rand(T), -rand(T)]
- end
+ # 1-arg functions
+ @testset "$(fun)()::$T" for fun in MATH_INTR_FUNCS_1_ARG, T in (Float32, Float16)
+ cpuarr = if fun in [log, log2, log10, Metal.rsqrt, sqrt]
+ rand(T, 4)
+ else
+ T[0.0, -0.0, rand(T), -rand(T)]
+ end
- mtlarr = MtlArray(cpuarr)
+ mtlarr = MtlArray(cpuarr)
- mtlout = fill!(similar(mtlarr), 0)
+ mtlout = fill!(similar(mtlarr), 0)
- function kernel(res, arr)
+ function kernel(res, arr)
idx = thread_position_in_grid_1d()
- res[idx] = fun(arr[idx])
+ res[idx] = fun(arr[idx])
return nothing
end
- Metal.@sync @metal threads = length(mtlout) kernel(mtlout, mtlarr)
- @eval @test Array($mtlout) ≈ $fun.($cpuarr)
-end
-# 2-arg functions
-@testset "$(fun)()::$T" for T in (Float32, Float16), fun in MATH_INTR_FUNCS_2_ARG
- N = 4
- arr1 = randn(T, N)
- arr2 = randn(T, N)
- mtlarr1 = MtlArray(arr1)
- mtlarr2 = MtlArray(arr2)
+ Metal.@sync @metal threads = length(mtlout) kernel(mtlout, mtlarr)
+ @eval @test Array($mtlout) ≈ $fun.($cpuarr)
+ end
+ # 2-arg functions
+ @testset "$(fun)()::$T" for T in (Float32, Float16), fun in MATH_INTR_FUNCS_2_ARG
+ N = 4
+ arr1 = randn(T, N)
+ arr2 = randn(T, N)
+ mtlarr1 = MtlArray(arr1)
+ mtlarr2 = MtlArray(arr2)
- mtlout = fill!(similar(mtlarr1), 0)
+ mtlout = fill!(similar(mtlarr1), 0)
- function kernel(res, x, y)
+ function kernel(res, x, y)
idx = thread_position_in_grid_1d()
- res[idx] = fun(x[idx], y[idx])
+ res[idx] = fun(x[idx], y[idx])
return nothing
end
- Metal.@sync @metal threads = N kernel(mtlout, mtlarr1, mtlarr2)
- @eval @test Array($mtlout) ≈ $fun.($arr1, $arr2)
-end
-# 3-arg functions
-@testset "$(fun)()::$T" for T in (Float32, Float16), fun in MATH_INTR_FUNCS_3_ARG
- N = 4
- arr1 = randn(T, N)
- arr2 = randn(T, N)
- arr3 = randn(T, N)
+ Metal.@sync @metal threads = N kernel(mtlout, mtlarr1, mtlarr2)
+ @eval @test Array($mtlout) ≈ $fun.($arr1, $arr2)
+ end
+ # 3-arg functions
+ @testset "$(fun)()::$T" for T in (Float32, Float16), fun in MATH_INTR_FUNCS_3_ARG
+ N = 4
+ arr1 = randn(T, N)
+ arr2 = randn(T, N)
+ arr3 = randn(T, N)
- mtlarr1 = MtlArray(arr1)
- mtlarr2 = MtlArray(arr2)
- mtlarr3 = MtlArray(arr3)
+ mtlarr1 = MtlArray(arr1)
+ mtlarr2 = MtlArray(arr2)
+ mtlarr3 = MtlArray(arr3)
- mtlout = fill!(similar(mtlarr1), 0)
+ mtlout = fill!(similar(mtlarr1), 0)
- function kernel(res, x, y, z)
+ function kernel(res, x, y, z)
idx = thread_position_in_grid_1d()
- res[idx] = fun(x[idx], y[idx], z[idx])
+ res[idx] = fun(x[idx], y[idx], z[idx])
return nothing
end
- Metal.@sync @metal threads = N kernel(mtlout, mtlarr1, mtlarr2, mtlarr3)
- @eval @test Array($mtlout) ≈ $fun.($arr1, $arr2, $arr3)
-end
+ Metal.@sync @metal threads = N kernel(mtlout, mtlarr1, mtlarr2, mtlarr3)
+ @eval @test Array($mtlout) ≈ $fun.($arr1, $arr2, $arr3)
+ end
end
@testset "unique math" begin
-@testset "$T" for T in (Float32, Float16)
- let # acosh
- arr = T[0, rand(T, 3)...] .+ T(1)
- buffer = MtlArray(arr)
- vec = acosh.(buffer)
- @test Array(vec) ≈ acosh.(arr)
- end
-
- let # sincos
- N = 4
- arr = rand(T, N)
- bufferA = MtlArray(arr)
- bufferB = MtlArray(arr)
- function intr_test3(arr_sin, arr_cos)
- idx = thread_position_in_grid_1d()
- sinres, cosres = sincos(arr_cos[idx])
- arr_sin[idx] = sinres
- arr_cos[idx] = cosres
- return nothing
+ @testset "$T" for T in (Float32, Float16)
+ let # acosh
+ arr = T[0, rand(T, 3)...] .+ T(1)
+ buffer = MtlArray(arr)
+ vec = acosh.(buffer)
+ @test Array(vec) ≈ acosh.(arr)
end
- # Broken with Float16
- if T == Float16
- @test_broken Metal.@sync @metal threads = N intr_test3(bufferA, bufferB)
- else
- Metal.@sync @metal threads = N intr_test3(bufferA, bufferB)
- @test Array(bufferA) ≈ sin.(arr)
- @test Array(bufferB) ≈ cos.(arr)
+
+ let # sincos
+ N = 4
+ arr = rand(T, N)
+ bufferA = MtlArray(arr)
+ bufferB = MtlArray(arr)
+ function intr_test3(arr_sin, arr_cos)
+ idx = thread_position_in_grid_1d()
+ sinres, cosres = sincos(arr_cos[idx])
+ arr_sin[idx] = sinres
+ arr_cos[idx] = cosres
+ return nothing
+ end
+ # Broken with Float16
+ if T == Float16
+ @test_broken Metal.@sync @metal threads = N intr_test3(bufferA, bufferB)
+ else
+ Metal.@sync @metal threads = N intr_test3(bufferA, bufferB)
+ @test Array(bufferA) ≈ sin.(arr)
+ @test Array(bufferB) ≈ cos.(arr)
+ end
end
- end
- let # clamp
- N = 4
- in = randn(T, N)
- minval = fill(T(-0.6), N)
- maxval = fill(T(0.6), N)
+ let # clamp
+ N = 4
+ in = randn(T, N)
+ minval = fill(T(-0.6), N)
+ maxval = fill(T(0.6), N)
- mtlin = MtlArray(in)
- mtlminval = MtlArray(minval)
- mtlmaxval = MtlArray(maxval)
+ mtlin = MtlArray(in)
+ mtlminval = MtlArray(minval)
+ mtlmaxval = MtlArray(maxval)
- mtlout = fill!(similar(mtlin), 0)
+ mtlout = fill!(similar(mtlin), 0)
- function kernel(res, x, y, z)
- idx = thread_position_in_grid_1d()
- res[idx] = clamp(x[idx], y[idx], z[idx])
- return nothing
+ function kernel(res, x, y, z)
+ idx = thread_position_in_grid_1d()
+ res[idx] = clamp(x[idx], y[idx], z[idx])
+ return nothing
+ end
+ Metal.@sync @metal threads = N kernel(mtlout, mtlin, mtlminval, mtlmaxval)
+ @test Array(mtlout) == clamp.(in, minval, maxval)
end
- Metal.@sync @metal threads = N kernel(mtlout, mtlin, mtlminval, mtlmaxval)
- @test Array(mtlout) == clamp.(in, minval, maxval)
- end
- let #pow
- N = 4
- arr1 = rand(T, N)
- arr2 = rand(T, N)
- mtlarr1 = MtlArray(arr1)
- mtlarr2 = MtlArray(arr2)
+ let #pow
+ N = 4
+ arr1 = rand(T, N)
+ arr2 = rand(T, N)
+ mtlarr1 = MtlArray(arr1)
+ mtlarr2 = MtlArray(arr2)
- mtlout = fill!(similar(mtlarr1), 0)
+ mtlout = fill!(similar(mtlarr1), 0)
- function kernel(res, x, y)
- idx = thread_position_in_grid_1d()
- res[idx] = x[idx]^y[idx]
- return nothing
+ function kernel(res, x, y)
+ idx = thread_position_in_grid_1d()
+ res[idx] = x[idx]^y[idx]
+ return nothing
+ end
+ Metal.@sync @metal threads = N kernel(mtlout, mtlarr1, mtlarr2)
+ @test Array(mtlout) ≈ arr1 .^ arr2
end
- Metal.@sync @metal threads = N kernel(mtlout, mtlarr1, mtlarr2)
- @test Array(mtlout) ≈ arr1 .^ arr2
- end
- let #powr
- N = 4
- arr1 = rand(T, N)
- arr2 = rand(T, N)
- mtlarr1 = MtlArray(arr1)
- mtlarr2 = MtlArray(arr2)
+ let #powr
+ N = 4
+ arr1 = rand(T, N)
+ arr2 = rand(T, N)
+ mtlarr1 = MtlArray(arr1)
+ mtlarr2 = MtlArray(arr2)
- mtlout = fill!(similar(mtlarr1), 0)
+ mtlout = fill!(similar(mtlarr1), 0)
- function kernel(res, x, y)
- idx = thread_position_in_grid_1d()
- res[idx] = Metal.powr(x[idx], y[idx])
- return nothing
+ function kernel(res, x, y)
+ idx = thread_position_in_grid_1d()
+ res[idx] = Metal.powr(x[idx], y[idx])
+ return nothing
+ end
+ Metal.@sync @metal threads = N kernel(mtlout, mtlarr1, mtlarr2)
+ @test Array(mtlout) ≈ arr1 .^ arr2
end
- Metal.@sync @metal threads = N kernel(mtlout, mtlarr1, mtlarr2)
- @test Array(mtlout) ≈ arr1 .^ arr2
- end
- let # log1p
- arr = collect(LinRange(nextfloat(-1.0f0), 10.0f0, 20))
- buffer = MtlArray(arr)
- vec = Array(log1p.(buffer))
- @test vec ≈ log1p.(arr)
- end
+ let # log1p
+ arr = collect(LinRange(nextfloat(-1.0f0), 10.0f0, 20))
+ buffer = MtlArray(arr)
+ vec = Array(log1p.(buffer))
+ @test vec ≈ log1p.(arr)
+ end
- let # erf
- arr = collect(LinRange(nextfloat(-3.0f0), 3.0f0, 20))
- buffer = MtlArray(arr)
- vec = Array(SpecialFunctions.erf.(buffer))
- @test vec ≈ SpecialFunctions.erf.(arr)
- end
+ let # erf
+ arr = collect(LinRange(nextfloat(-3.0f0), 3.0f0, 20))
+ buffer = MtlArray(arr)
+ vec = Array(SpecialFunctions.erf.(buffer))
+ @test vec ≈ SpecialFunctions.erf.(arr)
+ end
- let # erfc
- arr = collect(LinRange(nextfloat(-3.0f0), 3.0f0, 20))
- buffer = MtlArray(arr)
- vec = Array(SpecialFunctions.erfc.(buffer))
- @test vec ≈ SpecialFunctions.erfc.(arr)
- end
+ let # erfc
+ arr = collect(LinRange(nextfloat(-3.0f0), 3.0f0, 20))
+ buffer = MtlArray(arr)
+ vec = Array(SpecialFunctions.erfc.(buffer))
+ @test vec ≈ SpecialFunctions.erfc.(arr)
+ end
- let # erfinv
- arr = collect(LinRange(-1.0f0, 1.0f0, 20))
- buffer = MtlArray(arr)
- vec = Array(SpecialFunctions.erfinv.(buffer))
- @test vec ≈ SpecialFunctions.erfinv.(arr)
- end
+ let # erfinv
+ arr = collect(LinRange(-1.0f0, 1.0f0, 20))
+ buffer = MtlArray(arr)
+ vec = Array(SpecialFunctions.erfinv.(buffer))
+ @test vec ≈ SpecialFunctions.erfinv.(arr)
+ end
- let # expm1
- arr = collect(LinRange(nextfloat(-88.0f0), 88.0f0, 100))
- buffer = MtlArray(arr)
- vec = Array(expm1.(buffer))
- @test vec ≈ expm1.(arr)
- end
+ let # expm1
+ arr = collect(LinRange(nextfloat(-88.0f0), 88.0f0, 100))
+ buffer = MtlArray(arr)
+ vec = Array(expm1.(buffer))
+ @test vec ≈ expm1.(arr)
+ end
- let # nextafter
- if Metal.is_macos(v"14")
- N = 4
- function nextafter_test(X, y)
- idx = thread_position_in_grid_1d()
- X[idx] = Metal.nextafter(X[idx], y)
- return nothing
- end
- arr = rand(T, N)
- buffer = MtlArray(arr)
- Metal.@sync @metal threads = N nextafter_test(buffer, typemax(T))
- @test Array(buffer) == nextfloat.(arr)
+ let # nextafter
+ if Metal.is_macos(v"14")
+ N = 4
+ function nextafter_test(X, y)
+ idx = thread_position_in_grid_1d()
+ X[idx] = Metal.nextafter(X[idx], y)
+ return nothing
+ end
+ arr = rand(T, N)
+ buffer = MtlArray(arr)
+ Metal.@sync @metal threads = N nextafter_test(buffer, typemax(T))
+ @test Array(buffer) == nextfloat.(arr)
- Metal.@sync @metal threads = N nextafter_test(buffer, typemin(T))
- @test Array(buffer) == arr
+ Metal.@sync @metal threads = N nextafter_test(buffer, typemin(T))
+ @test Array(buffer) == arr
+ end
end
end
end
-end
############################################################################################
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Metal Benchmarks
Benchmark suite | Current: 8c4f365 | Previous: 1b811cb | Ratio |
---|---|---|---|
private array/construct |
24854.166666666668 ns |
26482.666666666668 ns |
0.94 |
private array/broadcast |
465500 ns |
457584 ns |
1.02 |
private array/random/randn/Float32 |
810625 ns |
857895.5 ns |
0.94 |
private array/random/randn!/Float32 |
631125 ns |
657709 ns |
0.96 |
private array/random/rand!/Int64 |
569458 ns |
564458 ns |
1.01 |
private array/random/rand!/Float32 |
592520.5 ns |
610416 ns |
0.97 |
private array/random/rand/Int64 |
743146 ns |
805354 ns |
0.92 |
private array/random/rand/Float32 |
601688 ns |
602791.5 ns |
1.00 |
private array/copyto!/gpu_to_gpu |
669520.5 ns |
672437.5 ns |
1.00 |
private array/copyto!/cpu_to_gpu |
778208 ns |
635271 ns |
1.23 |
private array/copyto!/gpu_to_cpu |
679312.5 ns |
815104.5 ns |
0.83 |
private array/accumulate/1d |
1341709 ns |
1324625 ns |
1.01 |
private array/accumulate/2d |
1389834 ns |
1393666.5 ns |
1.00 |
private array/iteration/findall/int |
2070250 ns |
2070667 ns |
1.00 |
private array/iteration/findall/bool |
1808853.5 ns |
1827208 ns |
0.99 |
private array/iteration/findfirst/int |
1736771.5 ns |
1688333.5 ns |
1.03 |
private array/iteration/findfirst/bool |
1679187.5 ns |
1671750 ns |
1.00 |
private array/iteration/scalar |
3441584 ns |
3643750 ns |
0.94 |
private array/iteration/logical |
3218750 ns |
3166208 ns |
1.02 |
private array/iteration/findmin/1d |
1764375 ns |
1758542 ns |
1.00 |
private array/iteration/findmin/2d |
1353250 ns |
1357792 ns |
1.00 |
private array/reductions/reduce/1d |
1050666.5 ns |
1043270.5 ns |
1.01 |
private array/reductions/reduce/2d |
666750 ns |
663333 ns |
1.01 |
private array/reductions/mapreduce/1d |
1038167 ns |
1043084 ns |
1.00 |
private array/reductions/mapreduce/2d |
670416 ns |
665896 ns |
1.01 |
private array/permutedims/4d |
2545750 ns |
2529437.5 ns |
1.01 |
private array/permutedims/2d |
1014916 ns |
1024875 ns |
0.99 |
private array/permutedims/3d |
1594083 ns |
1585208 ns |
1.01 |
private array/copy |
554709 ns |
592958 ns |
0.94 |
latency/precompile |
8822746791 ns |
8799224667 ns |
1.00 |
latency/ttfp |
3615673875 ns |
3600655125 ns |
1.00 |
latency/import |
1236130792 ns |
1231127083 ns |
1.00 |
integration/metaldevrt |
732667 ns |
701416 ns |
1.04 |
integration/byval/slices=1 |
1570917 ns |
1566354 ns |
1.00 |
integration/byval/slices=3 |
10470667 ns |
10376042 ns |
1.01 |
integration/byval/reference |
1558791 ns |
1610875 ns |
0.97 |
integration/byval/slices=2 |
2581083 ns |
2715041 ns |
0.95 |
kernel/indexing |
455500 ns |
474291.5 ns |
0.96 |
kernel/indexing_checked |
456083 ns |
475895.5 ns |
0.96 |
kernel/launch |
8375 ns |
8208 ns |
1.02 |
metal/synchronization/stream |
14750 ns |
15041 ns |
0.98 |
metal/synchronization/context |
15333 ns |
15000 ns |
1.02 |
shared array/construct |
24175 ns |
24145.833333333332 ns |
1.00 |
shared array/broadcast |
458250 ns |
461145.5 ns |
0.99 |
shared array/random/randn/Float32 |
839708 ns |
813750 ns |
1.03 |
shared array/random/randn!/Float32 |
631791 ns |
636542 ns |
0.99 |
shared array/random/rand!/Int64 |
557916 ns |
568417 ns |
0.98 |
shared array/random/rand!/Float32 |
594250 ns |
603334 ns |
0.98 |
shared array/random/rand/Int64 |
754250 ns |
778417 ns |
0.97 |
shared array/random/rand/Float32 |
589167 ns |
616250 ns |
0.96 |
shared array/copyto!/gpu_to_gpu |
83645.5 ns |
84667 ns |
0.99 |
shared array/copyto!/cpu_to_gpu |
83542 ns |
83375 ns |
1.00 |
shared array/copyto!/gpu_to_cpu |
84041 ns |
84625 ns |
0.99 |
shared array/accumulate/1d |
1351354 ns |
1346167 ns |
1.00 |
shared array/accumulate/2d |
1392291 ns |
1397291.5 ns |
1.00 |
shared array/iteration/findall/int |
1840375 ns |
1800250 ns |
1.02 |
shared array/iteration/findall/bool |
1580500 ns |
1590437.5 ns |
0.99 |
shared array/iteration/findfirst/int |
1401916.5 ns |
1408125 ns |
1.00 |
shared array/iteration/findfirst/bool |
1377333 ns |
1364584 ns |
1.01 |
shared array/iteration/scalar |
159458 ns |
153583 ns |
1.04 |
shared array/iteration/logical |
2998999.5 ns |
2981833.5 ns |
1.01 |
shared array/iteration/findmin/1d |
1474396 ns |
1467625 ns |
1.00 |
shared array/iteration/findmin/2d |
1374541.5 ns |
1369416 ns |
1.00 |
shared array/reductions/reduce/1d |
741500 ns |
738166.5 ns |
1.00 |
shared array/reductions/reduce/2d |
669042 ns |
663208.5 ns |
1.01 |
shared array/reductions/mapreduce/1d |
744312.5 ns |
746583 ns |
1.00 |
shared array/reductions/mapreduce/2d |
673229.5 ns |
670917 ns |
1.00 |
shared array/permutedims/4d |
2558541 ns |
2524958.5 ns |
1.01 |
shared array/permutedims/2d |
1024062.5 ns |
1028125 ns |
1.00 |
shared array/permutedims/3d |
1589917 ns |
1588833 ns |
1.00 |
shared array/copy |
247958 ns |
239042 ns |
1.04 |
This comment was automatically generated by workflow using github-action-benchmark.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Replaces #529
Description of each commit in order:
- `tanpi` has been in base Julia since 1.10, so switch it from `@device_function` to `@device_override`.
- `clamp` & `sign`
- `nextafter`
- `atan`
- `max` & `min`