Skip to content

MLXArray: Added Float16 and BFloat16 support #9

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 3 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions gen/generator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ options = load_options(joinpath(@__DIR__, "generator.toml"))
args = get_default_args()
push!(args, "-I$(joinpath(MLX_C_jll.artifact_dir, "include"))")

# Ensure Float16 and BFloat16 methods are generated - even on platforms without support
push!(args, "-DHAS_FLOAT16")
push!(args, "-DHAS_BFLOAT16")

headers = [
joinpath(include_dir, header) for
header in readdir(include_dir) if endswith(header, ".h")
Expand Down
16 changes: 16 additions & 0 deletions src/Wrapper.jl
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,14 @@ function mlx_array_item_complex64(res, arr)
ccall((:mlx_array_item_complex64, libmlxc), Cint, (Ptr{ComplexF32}, mlx_array), res, arr)
end

function mlx_array_item_float16(res, arr)
ccall((:mlx_array_item_float16, libmlxc), Cint, (Ptr{Cint}, mlx_array), res, arr)
end

function mlx_array_item_bfloat16(res, arr)
ccall((:mlx_array_item_bfloat16, libmlxc), Cint, (Ptr{Cint}, mlx_array), res, arr)
end

function mlx_array_data_bool(arr)
ccall((:mlx_array_data_bool, libmlxc), Ptr{Bool}, (mlx_array,), arr)
end
Expand Down Expand Up @@ -220,6 +228,14 @@ function mlx_array_data_complex64(arr)
ccall((:mlx_array_data_complex64, libmlxc), Ptr{ComplexF32}, (mlx_array,), arr)
end

function mlx_array_data_float16(arr)
ccall((:mlx_array_data_float16, libmlxc), Ptr{Cint}, (mlx_array,), arr)
end

function mlx_array_data_bfloat16(arr)
ccall((:mlx_array_data_bfloat16, libmlxc), Ptr{Cint}, (mlx_array,), arr)
end

struct mlx_closure_
ctx::Ptr{Cvoid}
end
Expand Down
12 changes: 9 additions & 3 deletions src/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@ mutable struct MLXArray{T, N} <: AbstractArray{T, N}
end

function Base.convert(::Type{Wrapper.mlx_dtype}, type::Type{<:Number})
@static if VERSION >= v"1.11"
if type == Core.BFloat16
return Wrapper.MLX_BFLOAT16
end
end
if type == Bool
return Wrapper.MLX_BOOL
elseif type == UInt8
Expand All @@ -31,7 +36,6 @@ function Base.convert(::Type{Wrapper.mlx_dtype}, type::Type{<:Number})
return Wrapper.MLX_FLOAT16
elseif type == Float32
return Wrapper.MLX_FLOAT32
# TODO Handle Wrapper.MLX_BFLOAT16
elseif type == ComplexF32
return Wrapper.MLX_COMPLEX64 # MLX_COMPLEX64 is a complex of Float32
else
Expand Down Expand Up @@ -120,10 +124,12 @@ function Base.unsafe_convert(::Type{Ptr{T}}, array::MLXArray{T, N}) where {T, N}
mlx_array_data = Wrapper.mlx_array_data_int32
elseif T == Int64
mlx_array_data = Wrapper.mlx_array_data_int64
# TODO generate wrapper on system with HAS_FLOAT16
elseif T == Float16
mlx_array_data = Wrapper.mlx_array_data_float16
elseif T == Float32
mlx_array_data = Wrapper.mlx_array_data_float32
# TODO generate wrapper on system with HAS_BFLOAT16
elseif T == Core.BFloat16
mlx_array_data = Wrapper.mlx_array_data_bfloat16
elseif T == ComplexF32
mlx_array_data = Wrapper.mlx_array_data_complex64
else
Expand Down
3 changes: 3 additions & 0 deletions test/array_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ using Test
ComplexF32,
)
array_sizes = [(), (1,), (2,), (1, 1), (2, 1), (2, 2), (1, 1, 1)]
@static if VERSION >= v"1.11"
element_types = (element_types..., Core.BFloat16)
end
for T in element_types, array_size in array_sizes
N = length(array_size)
@testset "$MLXArray{$T, $N} $array_size" begin
Expand Down
Loading