diff --git a/src/MLX.jl b/src/MLX.jl index 8d90e7a..0d4b07a 100644 --- a/src/MLX.jl +++ b/src/MLX.jl @@ -10,6 +10,8 @@ export MLXArray, MLXException, MLXMatrix, MLXVecOrMat, MLXVector include(joinpath(@__DIR__, "Wrapper.jl")) +include(joinpath(@__DIR__, "utils.jl")) + include(joinpath(@__DIR__, "array.jl")) include(joinpath(@__DIR__, "device.jl")) include(joinpath(@__DIR__, "error_handling.jl")) diff --git a/src/array.jl b/src/array.jl index c675d1a..0e09b46 100644 --- a/src/array.jl +++ b/src/array.jl @@ -42,10 +42,14 @@ function Base.convert(::Type{Wrapper.mlx_dtype}, type::Type{<:Number}) end function MLXArray{T, N}(array::AbstractArray{T, N}) where {T, N} - shape = Ref(Cint.(reverse(size(array)))) + is_column_major = + storage_order(array; preferred_order = ArrayStorageOrderRow) == + ArrayStorageOrderColumn + array_row_major = is_column_major ? permutedims(array, reverse(1:ndims(array))) : array + shape = collect(Cint.(size(array))) dtype = convert(Wrapper.mlx_dtype, T) - mlx_array = GC.@preserve array shape Wrapper.mlx_array_new_data( - pointer(array), shape, N, dtype + mlx_array = GC.@preserve array_row_major shape Wrapper.mlx_array_new_data( + pointer(array_row_major), pointer(shape), Cint(N), dtype ) return MLXArray{T, N}(mlx_array) end @@ -67,13 +71,11 @@ const MLXVecOrMat{T} = Union{MLXVector{T}, MLXMatrix{T}} function Base.size(array::MLXArray) return Tuple( Int.( - reverse( - unsafe_wrap( - Vector{Cint}, - Wrapper.mlx_array_shape(array.mlx_array), - Wrapper.mlx_array_ndim(array.mlx_array), - ), - ) + unsafe_wrap( + Vector{Cint}, + Wrapper.mlx_array_shape(array.mlx_array), + Wrapper.mlx_array_ndim(array.mlx_array), + ), ), ) end @@ -83,21 +85,20 @@ Base.IndexStyle(::Type{<:MLXArray}) = IndexLinear() Base.getindex(array::MLXArray, i::Int) = getindex(unsafe_wrap(array), i) function Base.setindex!(array::MLXArray{T, N}, v::T, i::Int) where {T, N} - return setindex!(unsafe_wrap(array), v, i) + setindex!(unsafe_wrap(array), v, i) + return array end -# StridedArray interface, cf. https://docs.julialang.org/en/v1/manual/interfaces/#man-interface-strided-arrays +# Strided array interface, cf. https://docs.julialang.org/en/v1/manual/interfaces/#man-interface-strided-arrays function Base.strides(array::MLXArray) return Tuple( Int.( - reverse( - unsafe_wrap( - Vector{Csize_t}, - Wrapper.mlx_array_strides(array.mlx_array), - Wrapper.mlx_array_ndim(array.mlx_array), - ), - ) + unsafe_wrap( + Vector{Csize_t}, + Wrapper.mlx_array_strides(array.mlx_array), + Wrapper.mlx_array_ndim(array.mlx_array), + ), ), ) end @@ -134,13 +135,25 @@ function Base.unsafe_convert(::Type{Ptr{T}}, array::MLXArray{T, N}) where {T, N} throw(ArgumentError("Unsupported type: $T")) end + Wrapper.mlx_array_eval(array.mlx_array) return mlx_array_data(array.mlx_array) end +Base.elsize(::Type{MLXArray{T, N}}) where {T, N} = sizeof(T) + function Base.elsize(array::MLXArray{T, N}) where {T, N} return Int(Wrapper.mlx_array_itemsize(array.mlx_array)) end function Base.unsafe_wrap(array::MLXArray{T, N}) where {T, N} - return unsafe_wrap(Array, Base.unsafe_convert(Ptr{T}, array), size(array)) + is_column_major = storage_order(array) == ArrayStorageOrderColumn + size_column_major = is_column_major ? size(array) : reverse(size(array)) + wrapped_array = unsafe_wrap( + Array, Base.unsafe_convert(Ptr{T}, array), size_column_major + ) + if is_column_major + return PermutedDimsArray(wrapped_array, 1:ndims(array)) + else + return PermutedDimsArray(wrapped_array, reverse(1:ndims(array))) + end end diff --git a/src/device.jl b/src/device.jl index a084374..9a81865 100644 --- a/src/device.jl +++ b/src/device.jl @@ -3,6 +3,31 @@ DeviceTypeGPU = 1 end +function supported_number_types(device_type::DeviceType = DeviceTypeCPU) + types = [ + Bool, + UInt8, + UInt16, + UInt32, + UInt64, + Int8, + Int16, + Int32, + Int64, + # TODO Float16, + Float32, + # TODO Core.BFloat16 + ComplexF32, + ] + if device_type == DeviceTypeCPU + return vcat(types, [Float64]) + elseif device_type == DeviceTypeGPU + return types + else + throw(ArgumentError("Unsupported device type: $device_type")) + end +end + mutable struct Device mlx_device::Wrapper.mlx_device diff --git a/src/utils.jl b/src/utils.jl new file mode 100644 index 0000000..b57d060 --- /dev/null +++ b/src/utils.jl @@ -0,0 +1,31 @@ +@enum ArrayStorageOrder begin + ArrayStorageOrderColumn + ArrayStorageOrderRow +end + +size_to_strides_column(sz::Dims) = Base.size_to_strides(1, sz...) + +size_to_strides_row(sz::Dims) = reverse(Base.size_to_strides(1, reverse(sz)...)) + +""" + storage_order(a::AbstractArray; preferred_order::ArrayStorageOrder = ArrayStorageOrderColumn) + +Returns the storage order of `a`, or `preferred_order` if `a` is a scalar or vector. +""" +function storage_order( + a::AbstractArray{T, N}; preferred_order::ArrayStorageOrder = ArrayStorageOrderColumn +) where {T, N} + if N < 2 + return preferred_order + end + + if strides(a) == size_to_strides_column(size(a)) + return ArrayStorageOrderColumn + elseif strides(a) == size_to_strides_row(size(a)) + return ArrayStorageOrderRow + else + throw( + ArgumentError("Unexpected strides $(strides(a)) for array of size $(size(a))") + ) + end +end diff --git a/test/array_tests.jl b/test/array_tests.jl index d94a78d..025688a 100644 --- a/test/array_tests.jl +++ b/test/array_tests.jl @@ -3,48 +3,63 @@ using Test @testset "MLXArray" begin @test IndexStyle(MLXArray) == IndexLinear() - element_types = ( - Bool, - UInt8, - UInt16, - UInt32, - UInt64, - Int8, - Int16, - Int32, - Int64, - # TODO Float16, - Float32, - Float64, - # TODO Core.BFloat16 - ComplexF32, - ) - array_sizes = [(), (1,), (2,), (1, 1), (2, 1), (2, 2), (1, 1, 1)] - for T in element_types, array_size in array_sizes - N = length(array_size) - @testset "$MLXArray{$T, $N} $array_size" begin - array = ones(T, array_size) - if N > 2 || N == 0 - mlx_array = MLXArray(array) - elseif N > 1 - mlx_array = MLXMatrix(array) - else - mlx_array = MLXVector(array) - end - @test eltype(mlx_array) == T - @test length(mlx_array) == length(array) - @test ndims(mlx_array) == ndims(array) - @test size(mlx_array) == size(array) - @test strides(mlx_array) == strides(array) - @test Base.elsize(mlx_array) == Base.elsize(array) - - if N > 0 - @test getindex(mlx_array, 1) == T(1) - array[1] = T(1) - @test setindex!(mlx_array, T(1), 1) == array + + array_sizes = [(), (1,), (2,), (1, 1), (2, 1), (3, 2), (4, 3, 2)] + + @testset "AbstractArray interface" begin + element_types = MLX.supported_number_types() + + for T in element_types, array_size in array_sizes + N = length(array_size) + @testset "$MLXArray{$T, $N}, array_size=$array_size" begin + array = ones(T, array_size) + if N > 2 || N == 0 + mlx_array = MLXArray(array) + elseif N > 1 + mlx_array = MLXMatrix(array) + else + mlx_array = MLXVector(array) + end + + @test eltype(mlx_array) == T + @test length(mlx_array) == length(array) + @test ndims(mlx_array) == ndims(array) + @test size(mlx_array) == size(array) + + if N > 0 + @test getindex(mlx_array, 1) == T(1) + array[1] = T(1) + @test setindex!(mlx_array, T(1), 1) == array + end end + end + end + @testset "Strided array interface" begin + element_types = MLX.supported_number_types(MLX.DeviceTypeGPU) # TODO Excluding Float64 + + for T in element_types, array_size in array_sizes + N = length(array_size) + @testset "$MLXArray{$T, $N}, array_size=$array_size" begin + array = ones(T, array_size) + mlx_array = MLXArray(array) - @test unsafe_wrap(mlx_array) == array + if N > 0 + @test strides(mlx_array) == + reverse(strides(permutedims(array, reverse(1:ndims(array))))) + else + @test strides(mlx_array) == strides(array) + end + @test Base.unsafe_convert(Ptr{T}, mlx_array) isa Ptr{T} + @test unsafe_wrap(mlx_array) == array + @test Base.elsize(mlx_array) == Base.elsize(array) + + another_mlx_array = MLXArray(mlx_array) + @test another_mlx_array == mlx_array + @test strides(another_mlx_array) == strides(mlx_array) + end + end + for T in element_types + @test Base.elsize(MLXArray{T, 0}) == Base.elsize(Array{T, 0}) end end @testset "Unsupported Number types" begin