Skip to content

Commit f20bbeb

Browse files
committed
Fixed similar and broadcasting wrt. supported_number_types(DeviceType)
1 parent 4090c73 commit f20bbeb

File tree

4 files changed

+66
-26
lines changed

4 files changed

+66
-26
lines changed

src/device.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,31 @@
33
DeviceTypeGPU = 1
44
end
55

6+
function supported_number_types(device_type::DeviceType = DeviceTypeCPU)
7+
types = [
8+
Bool,
9+
UInt8,
10+
UInt16,
11+
UInt32,
12+
UInt64,
13+
Int8,
14+
Int16,
15+
Int32,
16+
Int64,
17+
# TODO Float16,
18+
Float32,
19+
# TODO Core.BFloat16
20+
ComplexF32,
21+
]
22+
if device_type == DeviceTypeCPU
23+
return vcat(types, [Float64])
24+
elseif device_type == DeviceTypeGPU
25+
return types
26+
else
27+
throw(ArgumentError("Unsupported device type: $device_type"))
28+
end
29+
end
30+
631
mutable struct Device
732
mlx_device::Wrapper.mlx_device
833

test/array_tests.jl

Lines changed: 39 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,25 @@
1+
@static if VERSION < v"1.11"
2+
using ScopedValues
3+
else
4+
using Base.ScopedValues
5+
end
6+
17
using MLX
28
using Random
39
using Test
410

511
@testset "MLXArray" begin
612
Random.seed!(42)
713

14+
device_types = [MLX.DeviceTypeCPU]
15+
if MLX.metal_is_available()
16+
push!(device_types, MLX.DeviceTypeGPU)
17+
end
18+
19+
element_types = MLX.supported_number_types()
20+
821
@test IndexStyle(MLXArray) == IndexLinear()
9-
element_types = (
10-
Bool,
11-
UInt8,
12-
UInt16,
13-
UInt32,
14-
UInt64,
15-
Int8,
16-
Int16,
17-
Int32,
18-
Int64,
19-
# TODO Float16,
20-
Float32,
21-
Float64,
22-
# TODO Core.BFloat16
23-
ComplexF32,
24-
)
22+
2523
array_sizes = [(), (1,), (2,), (1, 1), (2, 1), (2, 2), (1, 1, 1)]
2624
@testset "AbstractArray interface" begin
2725
for T in element_types, array_size in array_sizes
@@ -51,29 +49,44 @@ using Test
5149
@test unsafe_wrap(mlx_array) == array
5250

5351
@testset "similar(::$MLXArray{$T, $N}), array_size=$array_size" begin
54-
similar_mlx_array = similar(mlx_array)
55-
@test typeof(similar_mlx_array) == typeof(mlx_array)
56-
@test size(similar_mlx_array) == size(mlx_array)
57-
@test similar_mlx_array !== mlx_array
52+
for device_type in device_types
53+
if T MLX.supported_number_types(device_type)
54+
continue
55+
end
56+
@testset "similar(::$MLXArray{$T, $N}), with array_size=$array_size, $device_type" begin
57+
with(MLX.device => MLX.Device(; device_type)) do
58+
similar_mlx_array = similar(mlx_array)
59+
@test typeof(similar_mlx_array) == typeof(mlx_array)
60+
@test size(similar_mlx_array) == size(mlx_array)
61+
@test similar_mlx_array !== mlx_array
62+
end
63+
end
64+
end
5865
end
5966
end
6067
end
6168
end
69+
6270
@testset "Unsupported Number types" begin
6371
@test_throws ArgumentError convert(MLX.Wrapper.mlx_dtype, Rational{Int})
6472
end
6573

6674
@testset "Broadcasting interface" begin
67-
for T in element_types, array_size in array_sizes
75+
for device_type in device_types,
76+
T in MLX.supported_number_types(device_type),
77+
array_size in array_sizes
78+
6879
N = length(array_size)
69-
@testset "broadcast(identity, ::$MLXArray{$T, $N}), array_size=$array_size" begin
80+
@testset "broadcast(identity, ::$MLXArray{$T, $N}), array_size=$array_size, $device_type" begin
7081
array = rand(T, array_size)
7182
mlx_array = MLXArray(array)
7283

73-
result = identity.(mlx_array)
74-
@test result isa MLXArray
75-
@test result == mlx_array
76-
@test result !== mlx_array
84+
with(MLX.device => MLX.Device(; device_type)) do
85+
result = identity.(mlx_array)
86+
@test result isa MLXArray
87+
@test result == mlx_array
88+
@test result !== mlx_array
89+
end
7790
end
7891
end
7992
end

test/device_tests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
else
44
using Base.ScopedValues
55
end
6+
67
using MLX
78
using Test
89

test/stream_tests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
else
44
using Base.ScopedValues
55
end
6+
67
using MLX
78
using Test
89

0 commit comments

Comments
 (0)