Skip to content

Fixed MLXArray storage order #22

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/MLX.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ export MLXArray, MLXException, MLXMatrix, MLXVecOrMat, MLXVector

include(joinpath(@__DIR__, "Wrapper.jl"))

include(joinpath(@__DIR__, "utils.jl"))

include(joinpath(@__DIR__, "array.jl"))
include(joinpath(@__DIR__, "device.jl"))
include(joinpath(@__DIR__, "error_handling.jl"))
Expand Down
53 changes: 33 additions & 20 deletions src/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,14 @@ function Base.convert(::Type{Wrapper.mlx_dtype}, type::Type{<:Number})
end

function MLXArray{T, N}(array::AbstractArray{T, N}) where {T, N}
shape = Ref(Cint.(reverse(size(array))))
is_column_major =
storage_order(array; preferred_order = ArrayStorageOrderRow) ==
ArrayStorageOrderColumn
array_row_major = is_column_major ? permutedims(array, reverse(1:ndims(array))) : array
shape = collect(Cint.(size(array)))
dtype = convert(Wrapper.mlx_dtype, T)
mlx_array = GC.@preserve array shape Wrapper.mlx_array_new_data(
pointer(array), shape, N, dtype
mlx_array = GC.@preserve array_row_major shape Wrapper.mlx_array_new_data(
pointer(array_row_major), pointer(shape), Cint(N), dtype
)
return MLXArray{T, N}(mlx_array)
end
Expand All @@ -67,13 +71,11 @@ const MLXVecOrMat{T} = Union{MLXVector{T}, MLXMatrix{T}}
function Base.size(array::MLXArray)
return Tuple(
Int.(
reverse(
unsafe_wrap(
Vector{Cint},
Wrapper.mlx_array_shape(array.mlx_array),
Wrapper.mlx_array_ndim(array.mlx_array),
),
)
unsafe_wrap(
Vector{Cint},
Wrapper.mlx_array_shape(array.mlx_array),
Wrapper.mlx_array_ndim(array.mlx_array),
),
),
)
end
Expand All @@ -83,21 +85,20 @@ Base.IndexStyle(::Type{<:MLXArray}) = IndexLinear()
Base.getindex(array::MLXArray, i::Int) = getindex(unsafe_wrap(array), i)

function Base.setindex!(array::MLXArray{T, N}, v::T, i::Int) where {T, N}
return setindex!(unsafe_wrap(array), v, i)
setindex!(unsafe_wrap(array), v, i)
return array
end

# StridedArray interface, cf. https://docs.julialang.org/en/v1/manual/interfaces/#man-interface-strided-arrays
# Strided array interface, cf. https://docs.julialang.org/en/v1/manual/interfaces/#man-interface-strided-arrays

function Base.strides(array::MLXArray)
return Tuple(
Int.(
reverse(
unsafe_wrap(
Vector{Csize_t},
Wrapper.mlx_array_strides(array.mlx_array),
Wrapper.mlx_array_ndim(array.mlx_array),
),
)
unsafe_wrap(
Vector{Csize_t},
Wrapper.mlx_array_strides(array.mlx_array),
Wrapper.mlx_array_ndim(array.mlx_array),
),
),
)
end
Expand Down Expand Up @@ -134,13 +135,25 @@ function Base.unsafe_convert(::Type{Ptr{T}}, array::MLXArray{T, N}) where {T, N}
throw(ArgumentError("Unsupported type: $T"))
end

Wrapper.mlx_array_eval(array.mlx_array)
return mlx_array_data(array.mlx_array)
end

Base.elsize(::Type{MLXArray{T, N}}) where {T, N} = sizeof(T)

function Base.elsize(array::MLXArray{T, N}) where {T, N}
return Int(Wrapper.mlx_array_itemsize(array.mlx_array))
end

function Base.unsafe_wrap(array::MLXArray{T, N}) where {T, N}
return unsafe_wrap(Array, Base.unsafe_convert(Ptr{T}, array), size(array))
is_column_major = storage_order(array) == ArrayStorageOrderColumn
size_column_major = is_column_major ? size(array) : reverse(size(array))
wrapped_array = unsafe_wrap(
Array, Base.unsafe_convert(Ptr{T}, array), size_column_major
)
if is_column_major
return PermutedDimsArray(wrapped_array, 1:ndims(array))
else
return PermutedDimsArray(wrapped_array, reverse(1:ndims(array)))
end
end
25 changes: 25 additions & 0 deletions src/device.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,31 @@
DeviceTypeGPU = 1
end

function supported_number_types(device_type::DeviceType = DeviceTypeCPU)
types = [
Bool,
UInt8,
UInt16,
UInt32,
UInt64,
Int8,
Int16,
Int32,
Int64,
# TODO Float16,
Float32,
# TODO Core.BFloat16
ComplexF32,
]
if device_type == DeviceTypeCPU
return vcat(types, [Float64])
elseif device_type == DeviceTypeGPU
return types
else
throw(ArgumentError("Unsupported device type: $device_type"))

Check warning on line 27 in src/device.jl

View check run for this annotation

Codecov / codecov/patch

src/device.jl#L27

Added line #L27 was not covered by tests
end
end

mutable struct Device
mlx_device::Wrapper.mlx_device

Expand Down
31 changes: 31 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
@enum ArrayStorageOrder begin
ArrayStorageOrderColumn
ArrayStorageOrderRow
end

size_to_strides_column(sz::Dims) = Base.size_to_strides(1, sz...)

size_to_strides_row(sz::Dims) = reverse(Base.size_to_strides(1, reverse(sz)...))

"""
storage_order(a::AbstractArray; preferred_order::ArrayStorageOrder = ArrayStorageOrderColumn)

Returns the storage order of `a`, or `preferred_order` if `a` is a scalar or vector.
"""
function storage_order(
a::AbstractArray{T, N}; preferred_order::ArrayStorageOrder = ArrayStorageOrderColumn
) where {T, N}
if N < 2
return preferred_order
end

if strides(a) == size_to_strides_column(size(a))
return ArrayStorageOrderColumn
elseif strides(a) == size_to_strides_row(size(a))
return ArrayStorageOrderRow
else
throw(

Check warning on line 27 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L27

Added line #L27 was not covered by tests
ArgumentError("Unexpected strides $(strides(a)) for array of size $(size(a))")
)
end
end
95 changes: 55 additions & 40 deletions test/array_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,48 +3,63 @@ using Test

@testset "MLXArray" begin
@test IndexStyle(MLXArray) == IndexLinear()
element_types = (
Bool,
UInt8,
UInt16,
UInt32,
UInt64,
Int8,
Int16,
Int32,
Int64,
# TODO Float16,
Float32,
Float64,
# TODO Core.BFloat16
ComplexF32,
)
array_sizes = [(), (1,), (2,), (1, 1), (2, 1), (2, 2), (1, 1, 1)]
for T in element_types, array_size in array_sizes
N = length(array_size)
@testset "$MLXArray{$T, $N} $array_size" begin
array = ones(T, array_size)
if N > 2 || N == 0
mlx_array = MLXArray(array)
elseif N > 1
mlx_array = MLXMatrix(array)
else
mlx_array = MLXVector(array)
end
@test eltype(mlx_array) == T
@test length(mlx_array) == length(array)
@test ndims(mlx_array) == ndims(array)
@test size(mlx_array) == size(array)
@test strides(mlx_array) == strides(array)
@test Base.elsize(mlx_array) == Base.elsize(array)

if N > 0
@test getindex(mlx_array, 1) == T(1)
array[1] = T(1)
@test setindex!(mlx_array, T(1), 1) == array

array_sizes = [(), (1,), (2,), (1, 1), (2, 1), (3, 2), (4, 3, 2)]

@testset "AbstractArray interface" begin
element_types = MLX.supported_number_types()

for T in element_types, array_size in array_sizes
N = length(array_size)
@testset "$MLXArray{$T, $N}, array_size=$array_size" begin
array = ones(T, array_size)
if N > 2 || N == 0
mlx_array = MLXArray(array)
elseif N > 1
mlx_array = MLXMatrix(array)
else
mlx_array = MLXVector(array)
end

@test eltype(mlx_array) == T
@test length(mlx_array) == length(array)
@test ndims(mlx_array) == ndims(array)
@test size(mlx_array) == size(array)

if N > 0
@test getindex(mlx_array, 1) == T(1)
array[1] = T(1)
@test setindex!(mlx_array, T(1), 1) == array
end
end
end
end
@testset "Strided array interface" begin
element_types = MLX.supported_number_types(MLX.DeviceTypeGPU) # TODO Excluding Float64

for T in element_types, array_size in array_sizes
N = length(array_size)
@testset "$MLXArray{$T, $N}, array_size=$array_size" begin
array = ones(T, array_size)
mlx_array = MLXArray(array)

@test unsafe_wrap(mlx_array) == array
if N > 0
@test strides(mlx_array) ==
reverse(strides(permutedims(array, reverse(1:ndims(array)))))
else
@test strides(mlx_array) == strides(array)
end
@test Base.unsafe_convert(Ptr{T}, mlx_array) isa Ptr{T}
@test unsafe_wrap(mlx_array) == array
@test Base.elsize(mlx_array) == Base.elsize(array)

another_mlx_array = MLXArray(mlx_array)
@test another_mlx_array == mlx_array
@test strides(another_mlx_array) == strides(mlx_array)
end
end
for T in element_types
@test Base.elsize(MLXArray{T, 0}) == Base.elsize(Array{T, 0})
end
end
@testset "Unsupported Number types" begin
Expand Down
Loading