diff --git a/Project.toml b/Project.toml index 507b6b0..4b0e56f 100644 --- a/Project.toml +++ b/Project.toml @@ -15,7 +15,8 @@ ScopedValues = "1" julia = "1" [extras] +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Test"] +test = ["Random", "Test"] diff --git a/src/array.jl b/src/array.jl index c675d1a..99e8658 100644 --- a/src/array.jl +++ b/src/array.jl @@ -86,6 +86,13 @@ function Base.setindex!(array::MLXArray{T, N}, v::T, i::Int) where {T, N} return setindex!(unsafe_wrap(array), v, i) end +function Base.similar(array::MLXArray{T, N}, ::Type{T}, ::Dims{N}) where {T, N} + stream = get_stream() + result_ref = Ref(Wrapper.mlx_array_new()) + Wrapper.mlx_zeros_like(result_ref, array.mlx_array, stream.mlx_stream) + return MLXArray{T, N}(result_ref[]) +end + # StridedArray interface, cf. https://docs.julialang.org/en/v1/manual/interfaces/#man-interface-strided-arrays function Base.strides(array::MLXArray) @@ -134,6 +141,7 @@ function Base.unsafe_convert(::Type{Ptr{T}}, array::MLXArray{T, N}) where {T, N} throw(ArgumentError("Unsupported type: $T")) end + Wrapper.mlx_array_eval(array.mlx_array) return mlx_array_data(array.mlx_array) end @@ -144,3 +152,21 @@ end function Base.unsafe_wrap(array::MLXArray{T, N}) where {T, N} return unsafe_wrap(Array, Base.unsafe_convert(Ptr{T}, array), size(array)) end + +# Broadcasting interface, cf. https://docs.julialang.org/en/v1/manual/interfaces/#man-interfaces-broadcasting + +Base.BroadcastStyle(::Type{<:MLXArray}) = Broadcast.ArrayStyle{MLXArray}() + +function Base.similar( + bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MLXArray}}, ::Type{TElement} +) where {TElement} + first_mlx_array(bc::Broadcast.Broadcasted) = first_mlx_array(bc.args) + function first_mlx_array(args::Tuple) + return first_mlx_array(first_mlx_array(args[1]), Base.tail(args)) + end + first_mlx_array(x) = x + first_mlx_array(::Tuple{}) = nothing + first_mlx_array(a::MLXArray, _) = a + first_mlx_array(::Any, rest) = first_mlx_array(rest) + return similar(first_mlx_array(bc)) +end diff --git a/src/device.jl b/src/device.jl index a084374..9a81865 100644 --- a/src/device.jl +++ b/src/device.jl @@ -3,6 +3,31 @@ DeviceTypeGPU = 1 end +function supported_number_types(device_type::DeviceType = DeviceTypeCPU) + types = [ + Bool, + UInt8, + UInt16, + UInt32, + UInt64, + Int8, + Int16, + Int32, + Int64, + # TODO Float16, + Float32, + # TODO Core.BFloat16 + ComplexF32, + ] + if device_type == DeviceTypeCPU + return vcat(types, [Float64]) + elseif device_type == DeviceTypeGPU + return types + else + throw(ArgumentError("Unsupported device type: $device_type")) + end +end + mutable struct Device mlx_device::Wrapper.mlx_device diff --git a/test/array_tests.jl b/test/array_tests.jl index d94a78d..2fc3520 100644 --- a/test/array_tests.jl +++ b/test/array_tests.jl @@ -1,53 +1,93 @@ +@static if VERSION < v"1.11" + using ScopedValues +else + using Base.ScopedValues +end + using MLX +using Random using Test @testset "MLXArray" begin + Random.seed!(42) + + device_types = [MLX.DeviceTypeCPU] + if MLX.metal_is_available() + push!(device_types, MLX.DeviceTypeGPU) + end + + element_types = MLX.supported_number_types() + @test IndexStyle(MLXArray) == IndexLinear() - element_types = ( - Bool, - UInt8, - UInt16, - UInt32, - UInt64, - Int8, - Int16, - Int32, - Int64, - # TODO Float16, - Float32, - Float64, - # TODO Core.BFloat16 - ComplexF32, - ) + array_sizes = [(), (1,), (2,), (1, 1), (2, 1), (2, 2), (1, 1, 1)] - for T in element_types, array_size in array_sizes - N = length(array_size) - @testset "$MLXArray{$T, $N} $array_size" begin - array = ones(T, array_size) - if N > 2 || N == 0 - mlx_array = MLXArray(array) - elseif N > 1 - mlx_array = MLXMatrix(array) - else - mlx_array = MLXVector(array) - end - @test eltype(mlx_array) == T - @test length(mlx_array) == length(array) - @test ndims(mlx_array) == ndims(array) - @test size(mlx_array) == size(array) - @test strides(mlx_array) == strides(array) - @test Base.elsize(mlx_array) == Base.elsize(array) - - if N > 0 - @test getindex(mlx_array, 1) == T(1) - array[1] = T(1) - @test setindex!(mlx_array, T(1), 1) == array - end + @testset "AbstractArray interface" begin + for T in element_types, array_size in array_sizes + N = length(array_size) + @testset "$MLXArray{$T, $N}, array_size=$array_size" begin + array = ones(T, array_size) + if N > 2 || N == 0 + mlx_array = MLXArray(array) + elseif N > 1 + mlx_array = MLXMatrix(array) + else + mlx_array = MLXVector(array) + end + @test eltype(mlx_array) == T + @test length(mlx_array) == length(array) + @test ndims(mlx_array) == ndims(array) + @test size(mlx_array) == size(array) + @test strides(mlx_array) == strides(array) + @test Base.elsize(mlx_array) == Base.elsize(array) + + if N > 0 + @test getindex(mlx_array, 1) == T(1) + array[1] = T(1) + @test setindex!(mlx_array, T(1), 1) == array + end - @test unsafe_wrap(mlx_array) == array + @test unsafe_wrap(mlx_array) == array + + @testset "similar(::$MLXArray{$T, $N}), array_size=$array_size" begin + for device_type in device_types + if T ∉ MLX.supported_number_types(device_type) + continue + end + @testset "similar(::$MLXArray{$T, $N}), with array_size=$array_size, $device_type" begin + with(MLX.device => MLX.Device(; device_type)) do + similar_mlx_array = similar(mlx_array) + @test typeof(similar_mlx_array) == typeof(mlx_array) + @test size(similar_mlx_array) == size(mlx_array) + @test similar_mlx_array !== mlx_array + end + end + end + end + end end end + @testset "Unsupported Number types" begin @test_throws ArgumentError convert(MLX.Wrapper.mlx_dtype, Rational{Int}) end + + @testset "Broadcasting interface" begin + for device_type in device_types, + T in MLX.supported_number_types(device_type), + array_size in array_sizes + + N = length(array_size) + @testset "broadcast(identity, ::$MLXArray{$T, $N}), array_size=$array_size, $device_type" begin + array = rand(T, array_size) + mlx_array = MLXArray(array) + + with(MLX.device => MLX.Device(; device_type)) do + result = identity.(mlx_array) + @test result isa MLXArray + @test result == mlx_array + @test result !== mlx_array + end + end + end + end end diff --git a/test/device_tests.jl b/test/device_tests.jl index 4051cab..3a7dbab 100644 --- a/test/device_tests.jl +++ b/test/device_tests.jl @@ -3,6 +3,7 @@ else using Base.ScopedValues end + using MLX using Test diff --git a/test/stream_tests.jl b/test/stream_tests.jl index 28b106f..b5a1429 100644 --- a/test/stream_tests.jl +++ b/test/stream_tests.jl @@ -3,6 +3,7 @@ else using Base.ScopedValues end + using MLX using Test