|
| 1 | +@static if VERSION < v"1.11" |
| 2 | + using ScopedValues |
| 3 | +else |
| 4 | + using Base.ScopedValues |
| 5 | +end |
| 6 | + |
1 | 7 | using MLX
|
2 | 8 | using Random
|
3 | 9 | using Test
|
4 | 10 |
|
5 | 11 | @testset "MLXArray" begin
|
6 | 12 | Random.seed!(42)
|
7 | 13 |
|
| 14 | + device_types = [MLX.DeviceTypeCPU] |
| 15 | + if MLX.metal_is_available() |
| 16 | + push!(device_types, MLX.DeviceTypeGPU) |
| 17 | + end |
| 18 | + |
| 19 | + element_types = MLX.supported_number_types() |
| 20 | + |
8 | 21 | @test IndexStyle(MLXArray) == IndexLinear()
|
9 |
| - element_types = ( |
10 |
| - Bool, |
11 |
| - UInt8, |
12 |
| - UInt16, |
13 |
| - UInt32, |
14 |
| - UInt64, |
15 |
| - Int8, |
16 |
| - Int16, |
17 |
| - Int32, |
18 |
| - Int64, |
19 |
| - # TODO Float16, |
20 |
| - Float32, |
21 |
| - Float64, |
22 |
| - # TODO Core.BFloat16 |
23 |
| - ComplexF32, |
24 |
| - ) |
| 22 | + |
25 | 23 | array_sizes = [(), (1,), (2,), (1, 1), (2, 1), (2, 2), (1, 1, 1)]
|
26 | 24 | @testset "AbstractArray interface" begin
|
27 | 25 | for T in element_types, array_size in array_sizes
|
@@ -51,29 +49,44 @@ using Test
|
51 | 49 | @test unsafe_wrap(mlx_array) == array
|
52 | 50 |
|
53 | 51 | @testset "similar(::$MLXArray{$T, $N}), array_size=$array_size" begin
|
54 |
| - similar_mlx_array = similar(mlx_array) |
55 |
| - @test typeof(similar_mlx_array) == typeof(mlx_array) |
56 |
| - @test size(similar_mlx_array) == size(mlx_array) |
57 |
| - @test similar_mlx_array !== mlx_array |
| 52 | + for device_type in device_types |
| 53 | + if T ∉ MLX.supported_number_types(device_type) |
| 54 | + continue |
| 55 | + end |
| 56 | + @testset "similar(::$MLXArray{$T, $N}), with array_size=$array_size, $device_type" begin |
| 57 | + with(MLX.device => MLX.Device(; device_type)) do |
| 58 | + similar_mlx_array = similar(mlx_array) |
| 59 | + @test typeof(similar_mlx_array) == typeof(mlx_array) |
| 60 | + @test size(similar_mlx_array) == size(mlx_array) |
| 61 | + @test similar_mlx_array !== mlx_array |
| 62 | + end |
| 63 | + end |
| 64 | + end |
58 | 65 | end
|
59 | 66 | end
|
60 | 67 | end
|
61 | 68 | end
|
| 69 | + |
62 | 70 | @testset "Unsupported Number types" begin
|
63 | 71 | @test_throws ArgumentError convert(MLX.Wrapper.mlx_dtype, Rational{Int})
|
64 | 72 | end
|
65 | 73 |
|
66 | 74 | @testset "Broadcasting interface" begin
|
67 |
| - for T in element_types, array_size in array_sizes |
| 75 | + for device_type in device_types, |
| 76 | + T in MLX.supported_number_types(device_type), |
| 77 | + array_size in array_sizes |
| 78 | + |
68 | 79 | N = length(array_size)
|
69 |
| - @testset "broadcast(identity, ::$MLXArray{$T, $N}), array_size=$array_size" begin |
| 80 | + @testset "broadcast(identity, ::$MLXArray{$T, $N}), array_size=$array_size, $device_type" begin |
70 | 81 | array = rand(T, array_size)
|
71 | 82 | mlx_array = MLXArray(array)
|
72 | 83 |
|
73 |
| - result = identity.(mlx_array) |
74 |
| - @test result isa MLXArray |
75 |
| - @test result == mlx_array |
76 |
| - @test result !== mlx_array |
| 84 | + with(MLX.device => MLX.Device(; device_type)) do |
| 85 | + result = identity.(mlx_array) |
| 86 | + @test result isa MLXArray |
| 87 | + @test result == mlx_array |
| 88 | + @test result !== mlx_array |
| 89 | + end |
77 | 90 | end
|
78 | 91 | end
|
79 | 92 | end
|
|
0 commit comments