Skip to content
This repository was archived by the owner on Nov 4, 2024. It is now read-only.

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
fix: adapt ranges to JuliaGPU/Adapt.jl#86
Browse files Browse the repository at this point in the history
avik-pal committed Oct 25, 2024
1 parent c8ef590 commit bb8388d
Showing 8 changed files with 13 additions and 21 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -46,7 +46,7 @@ MLDataDevicesoneAPIExt = ["GPUArrays", "oneAPI"]

[compat]
AMDGPU = "0.9.6, 1"
Adapt = "4"
Adapt = "4.1"
CUDA = "5.2"
ChainRulesCore = "1.23"
Compat = "4.15"
3 changes: 2 additions & 1 deletion ext/MLDataDevicesChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
@@ -12,7 +12,8 @@ function ChainRulesCore.rrule(::typeof(Adapt.adapt), to::AbstractDevice, x::Abst
dev = get_device(x)
y = Adapt.adapt_storage(to, x)
if dev === nothing || dev isa UnknownDevice
@warn "`get_device(::$(typeof(x)))` returned `$(dev)`." maxlog=1
dev isa UnknownDevice &&
@warn "`get_device(::$(typeof(x)))` returned `$(dev)`." maxlog=1
∇adapt_storage_unknown = Δ -> (NoTangent(), NoTangent(), Δ)
return y, ∇adapt_storage_unknown
else
9 changes: 0 additions & 9 deletions src/public.jl
Original file line number Diff line number Diff line change
@@ -347,7 +347,6 @@ for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI, :XLA)
end
return map(D, x)
end

(D::$(ldev))(x::Union{Tuple, NamedTuple}) = map(D, x)
function (D::$(ldev))(x)
isleaf(x) && return Adapt.adapt(D, x)
@@ -376,14 +375,6 @@ for T in (AMDGPUDevice, CUDADevice, MetalDevice, oneAPIDevice, XLADevice)
end
end

Adapt.adapt_storage(::CPUDevice, x::AbstractRange) = x
Adapt.adapt_storage(::XLADevice, x::AbstractRange) = x
# Prevent Ambiguity
for T in (AMDGPUDevice, AMDGPUDevice{Nothing}, CUDADevice,
CUDADevice{Nothing}, MetalDevice, oneAPIDevice)
@eval Adapt.adapt_storage(to::$(T), x::AbstractRange) = Adapt.adapt(to, collect(x))
end

"""
isleaf(x) -> Bool
4 changes: 2 additions & 2 deletions test/amdgpu_tests.jl
Original file line number Diff line number Diff line change
@@ -53,7 +53,7 @@ using FillArrays, Zygote # Extensions
@test ps_xpu.mixed[1] isa Float32
@test ps_xpu.mixed[2] isa Float64
@test ps_xpu.mixed[3] isa aType
@test ps_xpu.range isa aType
@test ps_xpu.range isa AbstractRange
@test ps_xpu.e == ps.e
@test ps_xpu.d == ps.d
@test ps_xpu.rng_default isa rngType
@@ -83,7 +83,7 @@ using FillArrays, Zygote # Extensions
@test ps_cpu.mixed[1] isa Float32
@test ps_cpu.mixed[2] isa Float64
@test ps_cpu.mixed[3] isa Array
@test ps_cpu.range isa Array
@test ps_cpu.range isa AbstractRange
@test ps_cpu.e == ps.e
@test ps_cpu.d == ps.d
@test ps_cpu.rng_default isa Random.TaskLocalRNG
4 changes: 2 additions & 2 deletions test/cuda_tests.jl
Original file line number Diff line number Diff line change
@@ -52,7 +52,7 @@ using FillArrays, Zygote # Extensions
@test ps_xpu.mixed[1] isa Float32
@test ps_xpu.mixed[2] isa Float64
@test ps_xpu.mixed[3] isa aType
@test ps_xpu.range isa aType
@test ps_xpu.range isa AbstractRange
@test ps_xpu.e == ps.e
@test ps_xpu.d == ps.d
@test ps_xpu.rng_default isa rngType
@@ -82,7 +82,7 @@ using FillArrays, Zygote # Extensions
@test ps_cpu.mixed[1] isa Float32
@test ps_cpu.mixed[2] isa Float64
@test ps_cpu.mixed[3] isa Array
@test ps_cpu.range isa Array
@test ps_cpu.range isa AbstractRange
@test ps_cpu.e == ps.e
@test ps_cpu.d == ps.d
@test ps_cpu.rng_default isa Random.TaskLocalRNG
4 changes: 2 additions & 2 deletions test/metal_tests.jl
Original file line number Diff line number Diff line change
@@ -51,7 +51,7 @@ using FillArrays, Zygote # Extensions
@test ps_xpu.mixed[1] isa Float32
@test ps_xpu.mixed[2] isa Float64
@test ps_xpu.mixed[3] isa aType
@test ps_xpu.range isa aType
@test ps_xpu.range isa AbstractRange
@test ps_xpu.e == ps.e
@test ps_xpu.d == ps.d
@test ps_xpu.rng_default isa rngType
@@ -81,7 +81,7 @@ using FillArrays, Zygote # Extensions
@test ps_cpu.mixed[1] isa Float32
@test ps_cpu.mixed[2] isa Float64
@test ps_cpu.mixed[3] isa Array
@test ps_cpu.range isa Array
@test ps_cpu.range isa AbstractRange
@test ps_cpu.e == ps.e
@test ps_cpu.d == ps.d
@test ps_cpu.rng_default isa Random.TaskLocalRNG
4 changes: 2 additions & 2 deletions test/misc_tests.jl
Original file line number Diff line number Diff line change
@@ -55,12 +55,12 @@ end
gdev = gpu_device()
if !(gdev isa MetalDevice) # On intel devices causes problems
x = randn(10)
∂dev, ∂x = Zygote.gradient(sum ∘ Adapt.adapt_storage, gdev, x)
∂dev, ∂x = Zygote.gradient(sum ∘ Adapt.adapt, gdev, x)
@test ∂dev === nothing
@test ∂x ≈ ones(10)

x = randn(10) |> gdev
∂dev, ∂x = Zygote.gradient(sum ∘ Adapt.adapt_storage, cpu_device(), x)
∂dev, ∂x = Zygote.gradient(sum ∘ Adapt.adapt, cpu_device(), x)
@test ∂dev === nothing
@test ∂x ≈ gdev(ones(10))
@test get_device(∂x) isa parameterless_type(typeof(gdev))
4 changes: 2 additions & 2 deletions test/oneapi_tests.jl
Original file line number Diff line number Diff line change
@@ -51,7 +51,7 @@ using FillArrays, Zygote # Extensions
@test ps_xpu.mixed[1] isa Float32
@test ps_xpu.mixed[2] isa Float64
@test ps_xpu.mixed[3] isa aType
@test ps_xpu.range isa aType
@test ps_xpu.range isa AbstractRange
@test ps_xpu.e == ps.e
@test ps_xpu.d == ps.d
@test ps_xpu.rng_default isa rngType
@@ -81,7 +81,7 @@ using FillArrays, Zygote # Extensions
@test ps_cpu.mixed[1] isa Float32
@test ps_cpu.mixed[2] isa Float64
@test ps_cpu.mixed[3] isa Array
@test ps_cpu.range isa Array
@test ps_cpu.range isa AbstractRange
@test ps_cpu.e == ps.e
@test ps_cpu.d == ps.d
@test ps_cpu.rng_default isa Random.TaskLocalRNG

0 comments on commit bb8388d

Please sign in to comment.