Skip to content

Commit 9fcd4c1

Browse files
committed
Implemented broadcasting for MLXArray
Also, added necessary eval of MLX array data in Base.unsafe_convert(::Type{Ptr{T}}, array::MLXArray{T, N})
1 parent 585a561 commit 9fcd4c1

File tree

3 files changed

+78
-24
lines changed

3 files changed

+78
-24
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@ ScopedValues = "1"
1515
julia = "1"
1616

1717
[extras]
18+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1819
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1920

2021
[targets]
21-
test = ["Test"]
22+
test = ["Random", "Test"]

src/array.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,13 @@ function Base.setindex!(array::MLXArray{T, N}, v::T, i::Int) where {T, N}
8484
return setindex!(unsafe_wrap(array), v, i)
8585
end
8686

87+
function Base.similar(array::MLXArray{T, N}, ::Type{T}, ::Dims{N}) where {T, N}
88+
stream = get_stream()
89+
result_ref = Ref(Wrapper.mlx_array_new())
90+
Wrapper.mlx_zeros_like(result_ref, array.mlx_array, stream.mlx_stream)
91+
return MLXArray{T, N}(result_ref[])
92+
end
93+
8794
# StridedArray interface, cf. https://docs.julialang.org/en/v1/manual/interfaces/#man-interface-strided-arrays
8895

8996
function Base.strides(array::MLXArray)
@@ -130,6 +137,7 @@ function Base.unsafe_convert(::Type{Ptr{T}}, array::MLXArray{T, N}) where {T, N}
130137
throw(ArgumentError("Unsupported type: $T"))
131138
end
132139

140+
Wrapper.mlx_array_eval(array.mlx_array)
133141
return mlx_array_data(array.mlx_array)
134142
end
135143

@@ -140,3 +148,21 @@ end
140148
function Base.unsafe_wrap(array::MLXArray{T, N}) where {T, N}
141149
return unsafe_wrap(Array, Base.unsafe_convert(Ptr{T}, array), size(array))
142150
end
151+
152+
# Broadcasting interface, cf. https://docs.julialang.org/en/v1/manual/interfaces/#man-interfaces-broadcasting
153+
154+
Base.BroadcastStyle(::Type{<:MLXArray}) = Broadcast.ArrayStyle{MLXArray}()
155+
156+
function Base.similar(
157+
bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MLXArray}}, ::Type{TElement}
158+
) where {TElement}
159+
first_mlx_array(bc::Broadcast.Broadcasted) = first_mlx_array(bc.args)
160+
function first_mlx_array(args::Tuple)
161+
return first_mlx_array(first_mlx_array(args[1]), Base.tail(args))
162+
end
163+
first_mlx_array(x) = x
164+
first_mlx_array(::Tuple{}) = nothing
165+
first_mlx_array(a::MLXArray, _) = a
166+
first_mlx_array(::Any, rest) = first_mlx_array(rest)
167+
return similar(first_mlx_array(bc))
168+
end

test/array_tests.jl

Lines changed: 50 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
using MLX
2+
using Random
23
using Test
34

45
@testset "MLXArray" begin
6+
Random.seed!(42)
7+
58
@test IndexStyle(MLXArray) == IndexLinear()
69
element_types = (
710
Bool,
@@ -19,34 +22,58 @@ using Test
1922
ComplexF32,
2023
)
2124
array_sizes = [(), (1,), (2,), (1, 1), (2, 1), (2, 2), (1, 1, 1)]
22-
for T in element_types, array_size in array_sizes
23-
N = length(array_size)
24-
@testset "$MLXArray{$T, $N} $array_size" begin
25-
array = ones(T, array_size)
26-
if N > 2 || N == 0
27-
mlx_array = MLXArray(array)
28-
elseif N > 1
29-
mlx_array = MLXMatrix(array)
30-
else
31-
mlx_array = MLXVector(array)
32-
end
33-
@test eltype(mlx_array) == T
34-
@test length(mlx_array) == length(array)
35-
@test ndims(mlx_array) == ndims(array)
36-
@test size(mlx_array) == size(array)
37-
@test strides(mlx_array) == strides(array)
38-
@test Base.elsize(mlx_array) == Base.elsize(array)
25+
@testset "AbstractArray interface" begin
26+
for T in element_types, array_size in array_sizes
27+
N = length(array_size)
28+
@testset "$MLXArray{$T, $N}, array_size=$array_size" begin
29+
array = ones(T, array_size)
30+
if N > 2 || N == 0
31+
mlx_array = MLXArray(array)
32+
elseif N > 1
33+
mlx_array = MLXMatrix(array)
34+
else
35+
mlx_array = MLXVector(array)
36+
end
37+
@test eltype(mlx_array) == T
38+
@test length(mlx_array) == length(array)
39+
@test ndims(mlx_array) == ndims(array)
40+
@test size(mlx_array) == size(array)
41+
@test strides(mlx_array) == strides(array)
42+
@test Base.elsize(mlx_array) == Base.elsize(array)
3943

40-
if N > 0
41-
@test getindex(mlx_array, 1) == T(1)
42-
array[1] = T(1)
43-
@test setindex!(mlx_array, T(1), 1) == array
44-
end
44+
if N > 0
45+
@test getindex(mlx_array, 1) == T(1)
46+
array[1] = T(1)
47+
@test setindex!(mlx_array, T(1), 1) == array
48+
end
49+
50+
@test unsafe_wrap(mlx_array) == array
4551

46-
@test unsafe_wrap(mlx_array) == array
52+
@testset "similar(::$MLXArray{$T, $N}), array_size=$array_size" begin
53+
similar_mlx_array = similar(mlx_array)
54+
@test typeof(similar_mlx_array) == typeof(mlx_array)
55+
@test size(similar_mlx_array) == size(mlx_array)
56+
@test similar_mlx_array !== mlx_array
57+
end
58+
end
4759
end
4860
end
4961
for T in [Float64]
5062
@test_throws ArgumentError convert(MLX.Wrapper.mlx_dtype, T)
5163
end
64+
65+
@testset "Broadcasting interface" begin
66+
for T in element_types, array_size in array_sizes
67+
N = length(array_size)
68+
@testset "broadcast(identity, ::$MLXArray{$T, $N}), array_size=$array_size" begin
69+
array = rand(T, array_size)
70+
mlx_array = MLXArray(array)
71+
72+
result = identity.(mlx_array)
73+
@test result isa MLXArray
74+
@test result == mlx_array
75+
@test result !== mlx_array
76+
end
77+
end
78+
end
5279
end

0 commit comments

Comments
 (0)