diff --git a/gen/generator.jl b/gen/generator.jl index 0be1de9..6bd45af 100644 --- a/gen/generator.jl +++ b/gen/generator.jl @@ -10,6 +10,10 @@ options = load_options(joinpath(@__DIR__, "generator.toml")) args = get_default_args() push!(args, "-I$(joinpath(MLX_C_jll.artifact_dir, "include"))") +# Ensure Float16 and BFloat16 methods are generated - even on platforms without support +push!(args, "-DHAS_FLOAT16") +push!(args, "-DHAS_BFLOAT16") + headers = [ joinpath(include_dir, header) for header in readdir(include_dir) if endswith(header, ".h") diff --git a/src/Wrapper.jl b/src/Wrapper.jl index 8dc93ec..1c49de5 100644 --- a/src/Wrapper.jl +++ b/src/Wrapper.jl @@ -176,6 +176,14 @@ function mlx_array_item_complex64(res, arr) ccall((:mlx_array_item_complex64, libmlxc), Cint, (Ptr{ComplexF32}, mlx_array), res, arr) end +function mlx_array_item_float16(res, arr) + ccall((:mlx_array_item_float16, libmlxc), Cint, (Ptr{Cint}, mlx_array), res, arr) +end + +function mlx_array_item_bfloat16(res, arr) + ccall((:mlx_array_item_bfloat16, libmlxc), Cint, (Ptr{Cint}, mlx_array), res, arr) +end + function mlx_array_data_bool(arr) ccall((:mlx_array_data_bool, libmlxc), Ptr{Bool}, (mlx_array,), arr) end @@ -220,6 +228,14 @@ function mlx_array_data_complex64(arr) ccall((:mlx_array_data_complex64, libmlxc), Ptr{ComplexF32}, (mlx_array,), arr) end +function mlx_array_data_float16(arr) + ccall((:mlx_array_data_float16, libmlxc), Ptr{Cint}, (mlx_array,), arr) +end + +function mlx_array_data_bfloat16(arr) + ccall((:mlx_array_data_bfloat16, libmlxc), Ptr{Cint}, (mlx_array,), arr) +end + struct mlx_closure_ ctx::Ptr{Cvoid} end diff --git a/src/array.jl b/src/array.jl index 57c41c3..1241879 100644 --- a/src/array.jl +++ b/src/array.jl @@ -9,6 +9,11 @@ mutable struct MLXArray{T, N} <: AbstractArray{T, N} end function Base.convert(::Type{Wrapper.mlx_dtype}, type::Type{<:Number}) + @static if VERSION >= v"1.11" + if type == Core.BFloat16 + return Wrapper.MLX_BFLOAT16 + end + end if type == Bool return Wrapper.MLX_BOOL elseif type == UInt8 @@ -31,7 +36,6 @@ function Base.convert(::Type{Wrapper.mlx_dtype}, type::Type{<:Number}) return Wrapper.MLX_FLOAT16 elseif type == Float32 return Wrapper.MLX_FLOAT32 - # TODO Handle Wrapper.MLX_BFLOAT16 elseif type == ComplexF32 return Wrapper.MLX_COMPLEX64 # MLX_COMPLEX64 is a complex of Float32 else @@ -120,10 +124,12 @@ function Base.unsafe_convert(::Type{Ptr{T}}, array::MLXArray{T, N}) where {T, N} mlx_array_data = Wrapper.mlx_array_data_int32 elseif T == Int64 mlx_array_data = Wrapper.mlx_array_data_int64 - # TODO generate wrapper on system with HAS_FLOAT16 + elseif T == Float16 + mlx_array_data = Wrapper.mlx_array_data_float16 elseif T == Float32 mlx_array_data = Wrapper.mlx_array_data_float32 - # TODO generate wrapper on system with HAS_BFLOAT16 + elseif T == Core.BFloat16 + mlx_array_data = Wrapper.mlx_array_data_bfloat16 elseif T == ComplexF32 mlx_array_data = Wrapper.mlx_array_data_complex64 else diff --git a/test/array_tests.jl b/test/array_tests.jl index 4a37605..51d55a1 100644 --- a/test/array_tests.jl +++ b/test/array_tests.jl @@ -19,6 +19,9 @@ using Test ComplexF32, ) array_sizes = [(), (1,), (2,), (1, 1), (2, 1), (2, 2), (1, 1, 1)] + @static if VERSION >= v"1.11" + element_types = (element_types..., Core.BFloat16) + end for T in element_types, array_size in array_sizes N = length(array_size) @testset "$MLXArray{$T, $N} $array_size" begin