Skip to content

Implemented broadcasting for MLXArray #18

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ ScopedValues = "1"
julia = "1"

[extras]
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test"]
test = ["Random", "Test"]
26 changes: 26 additions & 0 deletions src/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,13 @@
return setindex!(unsafe_wrap(array), v, i)
end

function Base.similar(array::MLXArray{T, N}, ::Type{T}, ::Dims{N}) where {T, N}
stream = get_stream()
result_ref = Ref(Wrapper.mlx_array_new())
Wrapper.mlx_zeros_like(result_ref, array.mlx_array, stream.mlx_stream)
return MLXArray{T, N}(result_ref[])
end

# StridedArray interface, cf. https://docs.julialang.org/en/v1/manual/interfaces/#man-interface-strided-arrays

function Base.strides(array::MLXArray)
Expand Down Expand Up @@ -134,6 +141,7 @@
throw(ArgumentError("Unsupported type: $T"))
end

Wrapper.mlx_array_eval(array.mlx_array)
return mlx_array_data(array.mlx_array)
end

Expand All @@ -144,3 +152,21 @@
function Base.unsafe_wrap(array::MLXArray{T, N}) where {T, N}
return unsafe_wrap(Array, Base.unsafe_convert(Ptr{T}, array), size(array))
end

# Broadcasting interface, cf. https://docs.julialang.org/en/v1/manual/interfaces/#man-interfaces-broadcasting

Base.BroadcastStyle(::Type{<:MLXArray}) = Broadcast.ArrayStyle{MLXArray}()

function Base.similar(
bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MLXArray}}, ::Type{TElement}
) where {TElement}
first_mlx_array(bc::Broadcast.Broadcasted) = first_mlx_array(bc.args)
function first_mlx_array(args::Tuple)
return first_mlx_array(first_mlx_array(args[1]), Base.tail(args))
end
first_mlx_array(x) = x
first_mlx_array(::Tuple{}) = nothing

Check warning on line 168 in src/array.jl

View check run for this annotation

Codecov / codecov/patch

src/array.jl#L168

Added line #L168 was not covered by tests
first_mlx_array(a::MLXArray, _) = a
first_mlx_array(::Any, rest) = first_mlx_array(rest)

Check warning on line 170 in src/array.jl

View check run for this annotation

Codecov / codecov/patch

src/array.jl#L170

Added line #L170 was not covered by tests
return similar(first_mlx_array(bc))
end
25 changes: 25 additions & 0 deletions src/device.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,31 @@
DeviceTypeGPU = 1
end

function supported_number_types(device_type::DeviceType = DeviceTypeCPU)
types = [
Bool,
UInt8,
UInt16,
UInt32,
UInt64,
Int8,
Int16,
Int32,
Int64,
# TODO Float16,
Float32,
# TODO Core.BFloat16
ComplexF32,
]
if device_type == DeviceTypeCPU
return vcat(types, [Float64])
elseif device_type == DeviceTypeGPU
return types
else
throw(ArgumentError("Unsupported device type: $device_type"))

Check warning on line 27 in src/device.jl

View check run for this annotation

Codecov / codecov/patch

src/device.jl#L27

Added line #L27 was not covered by tests
end
end

mutable struct Device
mlx_device::Wrapper.mlx_device

Expand Down
120 changes: 80 additions & 40 deletions test/array_tests.jl
Original file line number Diff line number Diff line change
@@ -1,53 +1,93 @@
@static if VERSION < v"1.11"
using ScopedValues
else
using Base.ScopedValues
end

using MLX
using Random
using Test

@testset "MLXArray" begin
Random.seed!(42)

device_types = [MLX.DeviceTypeCPU]
if MLX.metal_is_available()
push!(device_types, MLX.DeviceTypeGPU)
end

element_types = MLX.supported_number_types()

@test IndexStyle(MLXArray) == IndexLinear()
element_types = (
Bool,
UInt8,
UInt16,
UInt32,
UInt64,
Int8,
Int16,
Int32,
Int64,
# TODO Float16,
Float32,
Float64,
# TODO Core.BFloat16
ComplexF32,
)

array_sizes = [(), (1,), (2,), (1, 1), (2, 1), (2, 2), (1, 1, 1)]
for T in element_types, array_size in array_sizes
N = length(array_size)
@testset "$MLXArray{$T, $N} $array_size" begin
array = ones(T, array_size)
if N > 2 || N == 0
mlx_array = MLXArray(array)
elseif N > 1
mlx_array = MLXMatrix(array)
else
mlx_array = MLXVector(array)
end
@test eltype(mlx_array) == T
@test length(mlx_array) == length(array)
@test ndims(mlx_array) == ndims(array)
@test size(mlx_array) == size(array)
@test strides(mlx_array) == strides(array)
@test Base.elsize(mlx_array) == Base.elsize(array)

if N > 0
@test getindex(mlx_array, 1) == T(1)
array[1] = T(1)
@test setindex!(mlx_array, T(1), 1) == array
end
@testset "AbstractArray interface" begin
for T in element_types, array_size in array_sizes
N = length(array_size)
@testset "$MLXArray{$T, $N}, array_size=$array_size" begin
array = ones(T, array_size)
if N > 2 || N == 0
mlx_array = MLXArray(array)
elseif N > 1
mlx_array = MLXMatrix(array)
else
mlx_array = MLXVector(array)
end
@test eltype(mlx_array) == T
@test length(mlx_array) == length(array)
@test ndims(mlx_array) == ndims(array)
@test size(mlx_array) == size(array)
@test strides(mlx_array) == strides(array)
@test Base.elsize(mlx_array) == Base.elsize(array)

if N > 0
@test getindex(mlx_array, 1) == T(1)
array[1] = T(1)
@test setindex!(mlx_array, T(1), 1) == array
end

@test unsafe_wrap(mlx_array) == array
@test unsafe_wrap(mlx_array) == array

@testset "similar(::$MLXArray{$T, $N}), array_size=$array_size" begin
for device_type in device_types
if T ∉ MLX.supported_number_types(device_type)
continue
end
@testset "similar(::$MLXArray{$T, $N}), with array_size=$array_size, $device_type" begin
with(MLX.device => MLX.Device(; device_type)) do
similar_mlx_array = similar(mlx_array)
@test typeof(similar_mlx_array) == typeof(mlx_array)
@test size(similar_mlx_array) == size(mlx_array)
@test similar_mlx_array !== mlx_array
end
end
end
end
end
end
end

@testset "Unsupported Number types" begin
@test_throws ArgumentError convert(MLX.Wrapper.mlx_dtype, Rational{Int})
end

@testset "Broadcasting interface" begin
for device_type in device_types,
T in MLX.supported_number_types(device_type),
array_size in array_sizes

N = length(array_size)
@testset "broadcast(identity, ::$MLXArray{$T, $N}), array_size=$array_size, $device_type" begin
array = rand(T, array_size)
mlx_array = MLXArray(array)

with(MLX.device => MLX.Device(; device_type)) do
result = identity.(mlx_array)
@test result isa MLXArray
@test result == mlx_array
@test result !== mlx_array
end
end
end
end
end
1 change: 1 addition & 0 deletions test/device_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
else
using Base.ScopedValues
end

using MLX
using Test

Expand Down
1 change: 1 addition & 0 deletions test/stream_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
else
using Base.ScopedValues
end

using MLX
using Test

Expand Down