Skip to content

Commit b66fd85

Browse files
committed
Fixup
1 parent 6c783f8 commit b66fd85

File tree

2 files changed

+11
-9
lines changed

2 files changed

+11
-9
lines changed

src/array.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,9 @@ function Base.convert(::Type{Wrapper.mlx_dtype}, type::Type{<:Number})
4242
end
4343

4444
function MLXArray{T, N}(array::AbstractArray{T, N}) where {T, N}
45-
is_column_major = storage_order(array, StorageOrderRow) == StorageOrderColumn
45+
is_column_major =
46+
storage_order(array; preferred_order = ArrayStorageOrderRow) ==
47+
ArrayStorageOrderColumn
4648
array_row_major = is_column_major ? permutedims(array, reverse(1:ndims(array))) : array
4749
shape = collect(Cint.(size(array)))
4850
dtype = convert(Wrapper.mlx_dtype, T)
@@ -144,7 +146,7 @@ function Base.elsize(array::MLXArray{T, N}) where {T, N}
144146
end
145147

146148
function Base.unsafe_wrap(array::MLXArray{T, N}) where {T, N}
147-
is_column_major = storage_order(array) == StorageOrderColumn
149+
is_column_major = storage_order(array) == ArrayStorageOrderColumn
148150
size_column_major = is_column_major ? size(array) : reverse(size(array))
149151
wrapped_array = unsafe_wrap(
150152
Array, Base.unsafe_convert(Ptr{T}, array), size_column_major

src/utils.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,28 @@
1-
@enum StorageOrder begin
2-
StorageOrderColumn
3-
StorageOrderRow
1+
@enum ArrayStorageOrder begin
2+
ArrayStorageOrderColumn
3+
ArrayStorageOrderRow
44
end
55

66
size_to_strides_column(sz::Dims) = Base.size_to_strides(1, sz...)
77

88
size_to_strides_row(sz::Dims) = reverse(Base.size_to_strides(1, reverse(sz)...))
99

1010
"""
11-
storage_order(a::AbstractArray, preferred_order::StorageOrder = StorageOrderColumn)
11+
storage_order(a::AbstractArray; preferred_order::ArrayStorageOrder = ArrayStorageOrderColumn)
1212
1313
Returns the storage order of `a`, or `preferred_order` if `a` is a scalar or vector.
1414
"""
1515
function storage_order(
16-
a::AbstractArray{T, N}, preferred_order::StorageOrder = StorageOrderColumn
16+
a::AbstractArray{T, N}; preferred_order::ArrayStorageOrder = ArrayStorageOrderColumn
1717
) where {T, N}
1818
if N < 2
1919
return preferred_order
2020
end
2121

2222
if strides(a) == size_to_strides_column(size(a))
23-
return StorageOrderColumn
23+
return ArrayStorageOrderColumn
2424
elseif strides(a) == size_to_strides_row(size(a))
25-
return StorageOrderRow
25+
return ArrayStorageOrderRow
2626
else
2727
throw(
2828
ArgumentError("Unexpected strides $(strides(a)) for array of size $(size(a))")

0 commit comments

Comments
 (0)