|
1 |
| -using SpecialFunctions |
2 | 1 | using Metal: metal_support
|
| 2 | +using Random |
| 3 | +using SpecialFunctions |
3 | 4 |
|
4 | 5 | @testset "arguments" begin
|
5 | 6 | @on_device dispatch_quadgroups_per_threadgroup()
|
@@ -103,71 +104,261 @@ end
|
103 | 104 |
|
104 | 105 | ############################################################################################
|
105 | 106 |
|
| 107 | +MATH_INTR_FUNCS_1_ARG = [ |
| 108 | + # Common functions |
| 109 | + # saturate, # T saturate(T x) Clamp between 0.0 and 1.0 |
| 110 | + sign, # T sign(T x) returns 0.0 if x is NaN |
| 111 | + |
| 112 | + # float math |
| 113 | + acos, # T acos(T x) |
| 114 | + asin, # T asin(T x) |
| 115 | + asinh, # T asinh(T x) |
| 116 | + atan, # T atan(T x) |
| 117 | + atanh, # T atanh(T x) |
| 118 | + ceil, # T ceil(T x) |
| 119 | + cos, # T cos(T x) |
| 120 | + cosh, # T cosh(T x) |
| 121 | + cospi, # T cospi(T x) |
| 122 | + exp, # T exp(T x) |
| 123 | + exp2, # T exp2(T x) |
| 124 | + exp10, # T exp10(T x) |
| 125 | + abs, # T [f]abs(T x) |
| 126 | + floor, # T floor(T x) |
| 127 | + Metal.fract, # T fract(T x) |
| 128 | + # ilogb, # Ti ilogb(T x) |
| 129 | + log, # T log(T x) |
| 130 | + log2, # T log2(T x) |
| 131 | + log10, # T log10(T x) |
| 132 | + # Metal.rint, # T rint(T x) # TODO: Add test. Not sure what the behaviour actually is |
| 133 | + round, # T round(T x) |
| 134 | + Metal.rsqrt, # T rsqrt(T x) |
| 135 | + sin, # T sin(T x) |
| 136 | + sinh, # T sinh(T x) |
| 137 | + sinpi, # T sinpi(T x) |
| 138 | + sqrt, # T sqrt(T x) |
| 139 | + tan, # T tan(T x) |
| 140 | + tanh, # T tanh(T x) |
| 141 | + tanpi, # T tanpi(T x) |
| 142 | + trunc, # T trunc(T x) |
| 143 | +] |
| 144 | +Metal.rsqrt(x::Float16) = 1 / sqrt(x) |
| 145 | +Metal.rsqrt(x::Float32) = 1 / sqrt(x) |
| 146 | +Metal.fract(x::Float16) = mod(x, 1) |
| 147 | +Metal.fract(x::Float32) = mod(x, 1) |
| 148 | + |
| 149 | +MATH_INTR_FUNCS_2_ARG = [ |
| 150 | + # Common functions |
| 151 | + # step, # T step(T edge, T x) Returns 0.0 if x < edge, otherwise it returns 1.0 |
| 152 | + |
| 153 | + # float math |
| 154 | + atan, # T atan2(T y, T x) Compute arc tangent of y over x. |
| 155 | + # fdim, # T fdim(T x, T y) |
| 156 | + max, # T [f]max(T x, T y) |
| 157 | + min, # T [f]min(T x, T y) |
| 158 | + # fmod, # T fmod(T x, T y) |
| 159 | + # frexp, # T frexp(T x, Ti &exponent) |
| 160 | + # ldexp, # T ldexp(T x, Ti k) |
| 161 | + # modf, # T modf(T x, T &intval) |
| 162 | + # nextafter, # T nextafter(T x, T y) # Metal 3.1+ |
| 163 | + # sincos, |
| 164 | + hypot, # Not an MSL intrinsic, but tested the same way |
| 165 | +] |
| 166 | + |
| 167 | +MATH_INTR_FUNCS_3_ARG = [ |
| 168 | + # Common functions |
| 169 | + # mix, # T mix(T x, T y, T a) # x+(y-x)*a |
| 170 | + # smoothstep, # T smoothstep(T edge0, T edge1, T x) |
| 171 | + fma, # T fma(T a, T b, T c) |
| 172 | + max, # T max3(T x, T y, T z) |
| 173 | + # median3, # T median3(T x, T y, T z) |
| 174 | + min, # T min3(T x, T y, T z) |
| 175 | +] |
| 176 | + |
106 | 177 | @testset "math" begin
|
107 |
| - a = ones(Float32,1) |
108 |
| - a .* Float32(3.14) |
109 |
| - bufferA = MtlArray{eltype(a),length(size(a)),Metal.SharedStorage}(a) |
110 |
| - vecA = unsafe_wrap(Vector{Float32}, pointer(bufferA), 1) |
| 178 | +# 1-arg functions |
| 179 | +@testset "$(fun)()::$T" for fun in MATH_INTR_FUNCS_1_ARG, T in (Float32, Float16) |
| 180 | + cpuarr = if fun in [log, log2, log10, Metal.rsqrt, sqrt] |
| 181 | + rand(T, 4) |
| 182 | + else |
| 183 | + T[0.0, -0.0, rand(T), -rand(T)] |
| 184 | + end |
| 185 | + |
| 186 | + mtlarr = MtlArray(cpuarr) |
| 187 | + |
| 188 | + mtlout = fill!(similar(mtlarr), 0) |
111 | 189 |
|
112 |
| - function intr_test(arr) |
| 190 | + function kernel(res, arr) |
113 | 191 | idx = thread_position_in_grid_1d()
|
114 |
| - arr[idx] = cos(arr[idx]) |
| 192 | + res[idx] = fun(arr[idx]) |
115 | 193 | return nothing
|
116 | 194 | end
|
117 |
| - @metal intr_test(bufferA) |
118 |
| - synchronize() |
119 |
| - @test vecA ≈ cos.(a) |
| 195 | + Metal.@sync @metal threads = length(mtlout) kernel(mtlout, mtlarr) |
| 196 | + @eval @test Array($mtlout) ≈ $fun.($cpuarr) |
| 197 | +end |
| 198 | +# 2-arg functions |
| 199 | +@testset "$(fun)()::$T" for T in (Float32, Float16), fun in MATH_INTR_FUNCS_2_ARG |
| 200 | + N = 4 |
| 201 | + arr1 = randn(T, N) |
| 202 | + arr2 = randn(T, N) |
| 203 | + mtlarr1 = MtlArray(arr1) |
| 204 | + mtlarr2 = MtlArray(arr2) |
| 205 | + |
| 206 | + mtlout = fill!(similar(mtlarr1), 0) |
120 | 207 |
|
121 |
| - function intr_test2(arr) |
| 208 | + function kernel(res, x, y) |
122 | 209 | idx = thread_position_in_grid_1d()
|
123 |
| - arr[idx] = Metal.rsqrt(arr[idx]) |
| 210 | + res[idx] = fun(x[idx], y[idx]) |
124 | 211 | return nothing
|
125 | 212 | end
|
126 |
| - @metal intr_test2(bufferA) |
127 |
| - synchronize() |
| 213 | + Metal.@sync @metal threads = N kernel(mtlout, mtlarr1, mtlarr2) |
| 214 | + @eval @test Array($mtlout) ≈ $fun.($arr1, $arr2) |
| 215 | +end |
| 216 | +# 3-arg functions |
| 217 | +@testset "$(fun)()::$T" for T in (Float32, Float16), fun in MATH_INTR_FUNCS_3_ARG |
| 218 | + N = 4 |
| 219 | + arr1 = randn(T, N) |
| 220 | + arr2 = randn(T, N) |
| 221 | + arr3 = randn(T, N) |
128 | 222 |
|
129 |
| - bufferB = MtlArray{eltype(a),length(size(a)),Metal.SharedStorage}(a) |
130 |
| - vecB = unsafe_wrap(Vector{Float32}, pointer(bufferB), 1) |
| 223 | + mtlarr1 = MtlArray(arr1) |
| 224 | + mtlarr2 = MtlArray(arr2) |
| 225 | + mtlarr3 = MtlArray(arr3) |
131 | 226 |
|
132 |
| - function intr_test3(arr_sin, arr_cos) |
| 227 | + mtlout = fill!(similar(mtlarr1), 0) |
| 228 | + |
| 229 | + function kernel(res, x, y, z) |
133 | 230 | idx = thread_position_in_grid_1d()
|
134 |
| - s, c = sincos(arr_cos[idx]) |
135 |
| - arr_sin[idx] = s |
136 |
| - arr_cos[idx] = c |
| 231 | + res[idx] = fun(x[idx], y[idx], z[idx]) |
137 | 232 | return nothing
|
138 | 233 | end
|
| 234 | + Metal.@sync @metal threads = N kernel(mtlout, mtlarr1, mtlarr2, mtlarr3) |
| 235 | + @eval @test Array($mtlout) ≈ $fun.($arr1, $arr2, $arr3) |
| 236 | +end |
| 237 | +end |
139 | 238 |
|
140 |
| - @metal intr_test3(bufferA, bufferB) |
141 |
| - synchronize() |
142 |
| - @test vecA ≈ sin.(a) |
143 |
| - @test vecB ≈ cos.(a) |
| 239 | +@testset "unique math" begin |
| 240 | +@testset "$T" for T in (Float32, Float16) |
| 241 | + let # acosh |
| 242 | + arr = T[0, rand(T, 3)...] .+ T(1) |
| 243 | + buffer = MtlArray(arr) |
| 244 | + vec = acosh.(buffer) |
| 245 | + @test Array(vec) ≈ acosh.(arr) |
| 246 | + end |
144 | 247 |
|
145 |
| - b = collect(LinRange(nextfloat(-1f0), 10f0, 20)) |
146 |
| - bufferC = MtlArray(b) |
147 |
| - vecC = Array(log1p.(bufferC)) |
148 |
| - @test vecC ≈ log1p.(b) |
| 248 | + let # sincos |
| 249 | + N = 4 |
| 250 | + arr = rand(T, N) |
| 251 | + bufferA = MtlArray(arr) |
| 252 | + bufferB = MtlArray(arr) |
| 253 | + function intr_test3(arr_sin, arr_cos) |
| 254 | + idx = thread_position_in_grid_1d() |
| 255 | + sinres, cosres = sincos(arr_cos[idx]) |
| 256 | + arr_sin[idx] = sinres |
| 257 | + arr_cos[idx] = cosres |
| 258 | + return nothing |
| 259 | + end |
| 260 | + # The sincos kernel is currently broken for Float16, so that case is only recorded as @test_broken |
| 261 | + if T == Float16 |
| 262 | + @test_broken Metal.@sync @metal threads = N intr_test3(bufferA, bufferB) |
| 263 | + else |
| 264 | + Metal.@sync @metal threads = N intr_test3(bufferA, bufferB) |
| 265 | + @test Array(bufferA) ≈ sin.(arr) |
| 266 | + @test Array(bufferB) ≈ cos.(arr) |
| 267 | + end |
| 268 | + end |
149 | 269 |
|
| 270 | + let # clamp |
| 271 | + N = 4 |
| 272 | + in = randn(T, N) |
| 273 | + minval = fill(T(-0.6), N) |
| 274 | + maxval = fill(T(0.6), N) |
150 | 275 |
|
151 |
| - d = collect(LinRange(nextfloat(-3.0f0), 3.0f0, 20)) |
152 |
| - bufferD = MtlArray(d) |
153 |
| - vecD = Array(SpecialFunctions.erf.(bufferD)) |
154 |
| - @test vecD ≈ SpecialFunctions.erf.(d) |
| 276 | + mtlin = MtlArray(in) |
| 277 | + mtlminval = MtlArray(minval) |
| 278 | + mtlmaxval = MtlArray(maxval) |
155 | 279 |
|
| 280 | + mtlout = fill!(similar(mtlin), 0) |
| 281 | + |
| 282 | + function kernel(res, x, y, z) |
| 283 | + idx = thread_position_in_grid_1d() |
| 284 | + res[idx] = clamp(x[idx], y[idx], z[idx]) |
| 285 | + return nothing |
| 286 | + end |
| 287 | + Metal.@sync @metal threads = N kernel(mtlout, mtlin, mtlminval, mtlmaxval) |
| 288 | + @test Array(mtlout) == clamp.(in, minval, maxval) |
| 289 | + end |
156 | 290 |
|
157 |
| - e = collect(LinRange(nextfloat(-3.0f0), 3.0f0, 20)) |
158 |
| - bufferE = MtlArray(e) |
159 |
| - vecE = Array(SpecialFunctions.erfc.(bufferE)) |
160 |
| - @test vecE ≈ SpecialFunctions.erfc.(e) |
| 291 | + let # pow |
| 292 | + N = 4 |
| 293 | + arr1 = rand(T, N) |
| 294 | + arr2 = rand(T, N) |
| 295 | + mtlarr1 = MtlArray(arr1) |
| 296 | + mtlarr2 = MtlArray(arr2) |
161 | 297 |
|
162 |
| - f = collect(LinRange(-1f0, 1f0, 20)) |
163 |
| - bufferF = MtlArray(f) |
164 |
| - vecF = Array(SpecialFunctions.erfinv.(bufferF)) |
165 |
| - @test vecF ≈ SpecialFunctions.erfinv.(f) |
| 298 | + mtlout = fill!(similar(mtlarr1), 0) |
166 | 299 |
|
167 |
| - f = collect(LinRange(nextfloat(-88f0), 88f0, 100)) |
168 |
| - bufferF = MtlArray(f) |
169 |
| - vecF = Array(expm1.(bufferF)) |
170 |
| - @test vecF ≈ expm1.(f) |
| 300 | + function kernel(res, x, y) |
| 301 | + idx = thread_position_in_grid_1d() |
| 302 | + res[idx] = x[idx]^y[idx] |
| 303 | + return nothing |
| 304 | + end |
| 305 | + Metal.@sync @metal threads = N kernel(mtlout, mtlarr1, mtlarr2) |
| 306 | + @test Array(mtlout) ≈ arr1 .^ arr2 |
| 307 | + end |
| 308 | + |
| 309 | + let # powr |
| 310 | + N = 4 |
| 311 | + arr1 = rand(T, N) |
| 312 | + arr2 = rand(T, N) |
| 313 | + mtlarr1 = MtlArray(arr1) |
| 314 | + mtlarr2 = MtlArray(arr2) |
| 315 | + |
| 316 | + mtlout = fill!(similar(mtlarr1), 0) |
| 317 | + |
| 318 | + function kernel(res, x, y) |
| 319 | + idx = thread_position_in_grid_1d() |
| 320 | + res[idx] = Metal.powr(x[idx], y[idx]) |
| 321 | + return nothing |
| 322 | + end |
| 323 | + Metal.@sync @metal threads = N kernel(mtlout, mtlarr1, mtlarr2) |
| 324 | + @test Array(mtlout) ≈ arr1 .^ arr2 |
| 325 | + end |
| 326 | + |
| 327 | + let # log1p |
| 328 | + arr = collect(LinRange(nextfloat(-1.0f0), 10.0f0, 20)) |
| 329 | + buffer = MtlArray(arr) |
| 330 | + vec = Array(log1p.(buffer)) |
| 331 | + @test vec ≈ log1p.(arr) |
| 332 | + end |
| 333 | + |
| 334 | + let # erf |
| 335 | + arr = collect(LinRange(nextfloat(-3.0f0), 3.0f0, 20)) |
| 336 | + buffer = MtlArray(arr) |
| 337 | + vec = Array(SpecialFunctions.erf.(buffer)) |
| 338 | + @test vec ≈ SpecialFunctions.erf.(arr) |
| 339 | + end |
| 340 | + |
| 341 | + let # erfc |
| 342 | + arr = collect(LinRange(nextfloat(-3.0f0), 3.0f0, 20)) |
| 343 | + buffer = MtlArray(arr) |
| 344 | + vec = Array(SpecialFunctions.erfc.(buffer)) |
| 345 | + @test vec ≈ SpecialFunctions.erfc.(arr) |
| 346 | + end |
| 347 | + |
| 348 | + let # erfinv |
| 349 | + arr = collect(LinRange(-1.0f0, 1.0f0, 20)) |
| 350 | + buffer = MtlArray(arr) |
| 351 | + vec = Array(SpecialFunctions.erfinv.(buffer)) |
| 352 | + @test vec ≈ SpecialFunctions.erfinv.(arr) |
| 353 | + end |
| 354 | + |
| 355 | + let # expm1 |
| 356 | + arr = collect(LinRange(nextfloat(-88.0f0), 88.0f0, 100)) |
| 357 | + buffer = MtlArray(arr) |
| 358 | + vec = Array(expm1.(buffer)) |
| 359 | + @test vec ≈ expm1.(arr) |
| 360 | + end |
| 361 | +end |
171 | 362 | end
|
172 | 363 |
|
173 | 364 | ############################################################################################
|
|
0 commit comments