Skip to content

Commit 1d862bc

Browse files
committed
WIP
TODO: Needs guarding against platforms without Float16 and/or BFloat16 in the MLX JLL.
1 parent 1dc2544 commit 1d862bc

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

src/array.jl

+4-2
Original file line numberDiff line numberDiff line change
@@ -124,10 +124,12 @@ function Base.unsafe_convert(::Type{Ptr{T}}, array::MLXArray{T, N}) where {T, N}
124124
mlx_array_data = Wrapper.mlx_array_data_int32
125125
elseif T == Int64
126126
mlx_array_data = Wrapper.mlx_array_data_int64
127-
# TODO generate wrapper on system with HAS_FLOAT16
127+
elseif T == Float16
128+
mlx_array_data = Wrapper.mlx_array_data_float16
128129
elseif T == Float32
129130
mlx_array_data = Wrapper.mlx_array_data_float32
130-
# TODO generate wrapper on system with HAS_BFLOAT16
131+
elseif T == Core.BFloat16
132+
mlx_array_data = Wrapper.mlx_array_data_bfloat16
131133
elseif T == ComplexF32
132134
mlx_array_data = Wrapper.mlx_array_data_complex64
133135
else

0 commit comments

Comments
 (0)