1
1
using MLX
2
+ using Random
2
3
using Test
3
4
4
5
@testset " MLXArray" begin
6
+ Random. seed! (42 )
7
+
5
8
@test IndexStyle (MLXArray) == IndexLinear ()
6
9
element_types = (
7
10
Bool,
@@ -19,34 +22,58 @@ using Test
19
22
ComplexF32,
20
23
)
21
24
array_sizes = [(), (1 ,), (2 ,), (1 , 1 ), (2 , 1 ), (2 , 2 ), (1 , 1 , 1 )]
22
- for T in element_types, array_size in array_sizes
23
- N = length (array_size)
24
- @testset " $MLXArray {$T , $N } $array_size " begin
25
- array = ones (T, array_size)
26
- if N > 2 || N == 0
27
- mlx_array = MLXArray (array)
28
- elseif N > 1
29
- mlx_array = MLXMatrix (array)
30
- else
31
- mlx_array = MLXVector (array)
32
- end
33
- @test eltype (mlx_array) == T
34
- @test length (mlx_array) == length (array)
35
- @test ndims (mlx_array) == ndims (array)
36
- @test size (mlx_array) == size (array)
37
- @test strides (mlx_array) == strides (array)
38
- @test Base. elsize (mlx_array) == Base. elsize (array)
25
+ @testset " AbstractArray interface" begin
26
+ for T in element_types, array_size in array_sizes
27
+ N = length (array_size)
28
+ @testset " $MLXArray {$T , $N }, array_size=$array_size " begin
29
+ array = ones (T, array_size)
30
+ if N > 2 || N == 0
31
+ mlx_array = MLXArray (array)
32
+ elseif N > 1
33
+ mlx_array = MLXMatrix (array)
34
+ else
35
+ mlx_array = MLXVector (array)
36
+ end
37
+ @test eltype (mlx_array) == T
38
+ @test length (mlx_array) == length (array)
39
+ @test ndims (mlx_array) == ndims (array)
40
+ @test size (mlx_array) == size (array)
41
+ @test strides (mlx_array) == strides (array)
42
+ @test Base. elsize (mlx_array) == Base. elsize (array)
39
43
40
- if N > 0
41
- @test getindex (mlx_array, 1 ) == T (1 )
42
- array[1 ] = T (1 )
43
- @test setindex! (mlx_array, T (1 ), 1 ) == array
44
- end
44
+ if N > 0
45
+ @test getindex (mlx_array, 1 ) == T (1 )
46
+ array[1 ] = T (1 )
47
+ @test setindex! (mlx_array, T (1 ), 1 ) == array
48
+ end
49
+
50
+ @test unsafe_wrap (mlx_array) == array
45
51
46
- @test unsafe_wrap (mlx_array) == array
52
+ @testset " similar(::$MLXArray {$T , $N }), array_size=$array_size " begin
53
+ similar_mlx_array = similar (mlx_array)
54
+ @test typeof (similar_mlx_array) == typeof (mlx_array)
55
+ @test size (similar_mlx_array) == size (mlx_array)
56
+ @test similar_mlx_array != = mlx_array
57
+ end
58
+ end
47
59
end
48
60
end
49
61
for T in [Float64]
50
62
@test_throws ArgumentError convert (MLX. Wrapper. mlx_dtype, T)
51
63
end
64
+
65
+ @testset " Broadcasting interface" begin
66
+ for T in element_types, array_size in array_sizes
67
+ N = length (array_size)
68
+ @testset " broadcast(identity, ::$MLXArray {$T , $N }), array_size=$array_size " begin
69
+ array = rand (T, array_size)
70
+ mlx_array = MLXArray (array)
71
+
72
+ result = identity .(mlx_array)
73
+ @test result isa MLXArray
74
+ @test result == mlx_array
75
+ @test result != = mlx_array
76
+ end
77
+ end
78
+ end
52
79
end
0 commit comments