From 3695b235ea3cb4d49818663baaddfcdba8645811 Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Fri, 25 Oct 2024 15:56:34 -0400
Subject: [PATCH] fix: use fast paths for adapt

Replace the hand-written fast paths for `Transpose`, `Adjoint`, and
`PermutedDimsArray` (kept for type stability) with a single dispatch on
`Adapt.WrappedArray`, which covers all of Adapt.jl's array wrappers and lets
us drop the `LinearAlgebra` dependency. The `rrule` in the ChainRulesCore
extension moves from `Adapt.adapt_storage` to `Adapt.adapt` to match, and a
`Zygote.gradient` test through wrapped arrays is added.
---
 Project.toml                          |  2 --
 ext/MLDataDevicesChainRulesCoreExt.jl |  3 +--
 src/MLDataDevices.jl                  |  1 -
 src/public.jl                         | 26 +++++---------------------
 test/misc_tests.jl                    | 13 +++++++++++++
 5 files changed, 19 insertions(+), 26 deletions(-)

diff --git a/Project.toml b/Project.toml
index 391724d..7d94339 100644
--- a/Project.toml
+++ b/Project.toml
@@ -7,7 +7,6 @@ version = "1.4.2"
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
 Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
 Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
-LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 Preferences = "21216c6a-2e73-6563-6e65-726566657250"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 
@@ -54,7 +53,6 @@ Compat = "4.15"
 FillArrays = "1"
 Functors = "0.4.8"
 GPUArrays = "10, 11"
-LinearAlgebra = "1.10"
 MLUtils = "0.4.4"
 Metal = "1"
 Preferences = "1.4"
diff --git a/ext/MLDataDevicesChainRulesCoreExt.jl b/ext/MLDataDevicesChainRulesCoreExt.jl
index 2b8c9c8..e625dc1 100644
--- a/ext/MLDataDevicesChainRulesCoreExt.jl
+++ b/ext/MLDataDevicesChainRulesCoreExt.jl
@@ -8,8 +8,7 @@ using MLDataDevices: AbstractDevice, UnknownDevice, get_device, get_device_type
 @non_differentiable get_device(::Any)
 @non_differentiable get_device_type(::Any)
 
-function ChainRulesCore.rrule(
-        ::typeof(Adapt.adapt_storage), to::AbstractDevice, x::AbstractArray)
+function ChainRulesCore.rrule(::typeof(Adapt.adapt), to::AbstractDevice, x::AbstractArray)
     dev = get_device(x)
     y = Adapt.adapt_storage(to, x)
     if dev === nothing || dev isa UnknownDevice
diff --git a/src/MLDataDevices.jl b/src/MLDataDevices.jl
index c837887..108d8bf 100644
--- a/src/MLDataDevices.jl
+++ b/src/MLDataDevices.jl
@@ -5,7 +5,6 @@ using Functors: Functors, fleaves
 using Preferences: @delete_preferences!, @load_preference, @set_preferences!
 using Random: AbstractRNG, Random
 using Compat: @compat
-using LinearAlgebra: Transpose, Adjoint
 
 abstract type AbstractDevice <: Function end
 abstract type AbstractCPUDevice <: AbstractDevice end
diff --git a/src/public.jl b/src/public.jl
index 6f7c8b8..b6ee2c4 100644
--- a/src/public.jl
+++ b/src/public.jl
@@ -342,29 +342,13 @@ for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI, :XLA)
     ldev = Symbol(dev, :Device)
     @eval begin
         function (D::$(ldev))(x::AbstractArray{T}) where {T}
-            if isleaf(x)
-                (isbitstype(T) || Internal.special_aos(x)) && return Adapt.adapt(D, x)
-                return map(D, x)
+            if isbitstype(T) || Internal.special_aos(x) || x isa Adapt.WrappedArray
+                return Adapt.adapt(D, x)
             end
-            return Functors.fmap(D, x; exclude=isleaf)
-        end
-        # Fast Paths else we don't get type stability
-        function (D::$(ldev))(x::Transpose{T, <:AbstractArray{T}}) where {T}
-            return transpose(D(parent(x)))
-        end
-        function (D::$(ldev))(x::Adjoint{T, <:AbstractArray{T}}) where {T}
-            return adjoint(D(parent(x)))
-        end
-        function (D::$(ldev))(x::PermutedDimsArray{
-                T, N, perm, iperm, <:AbstractArray{T}}) where {T, N, perm, iperm}
-            y = D(parent(x))
-            return PermutedDimsArray{eltype(y), N, perm, iperm, typeof(y)}(y)
+            return map(D, x)
         end
 
-        function (D::$(ldev))(x::Union{Tuple, NamedTuple})
-            isleaf(x) && map(D, x)
-            return Functors.fmap(D, x; exclude=isleaf)
-        end
+        (D::$(ldev))(x::Union{Tuple, NamedTuple}) = map(D, x)
         function (D::$(ldev))(x)
             isleaf(x) && return Adapt.adapt(D, x)
             return Functors.fmap(D, x; exclude=isleaf)
@@ -418,4 +402,4 @@ If `MLDataDevices.isleaf(x::T)` is not defined, then it will fall back to `Funct
 isleaf(x) = Functors.isleaf(x)
 
 isleaf(::AbstractArray{T}) where {T} = isbitstype(T)
-isleaf(::Union{Transpose, Adjoint, PermutedDimsArray}) = false
+isleaf(::Adapt.WrappedArray) = false
diff --git a/test/misc_tests.jl b/test/misc_tests.jl
index a1023cb..00a9905 100644
--- a/test/misc_tests.jl
+++ b/test/misc_tests.jl
@@ -206,3 +206,16 @@ end
         end
     end
 end
+
+@testset "Zygote.gradient(wrapped arrays)" begin
+    using Zygote
+
+    x = rand(4, 4)
+    cdev = cpu_device()
+
+    @test only(Zygote.gradient(x -> sum(abs2, cdev(x)), x')) isa Matrix{Float64}
+
+    gdev = gpu_device()
+
+    @test only(Zygote.gradient(x -> sum(abs2, gdev(x)), x')) isa Matrix{Float64}
+end
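
Note: the snippet below is an illustration, not part of the patch. It sketches
why one dispatch on `Adapt.WrappedArray` can replace the per-wrapper fast
paths: `Adapt.WrappedArray` is Adapt.jl's union over array wrappers
(`Transpose`, `Adjoint`, `PermutedDimsArray`, `SubArray`, ...), and
`Adapt.adapt` already knows how to unwrap, adapt the parent, and rewrap. It
assumes a CPU-only Julia session with this patch applied; only names that
appear in the diff (`cpu_device`, `MLDataDevices.isleaf`, `Adapt.WrappedArray`)
are used.

    using Adapt, LinearAlgebra, MLDataDevices

    cdev = cpu_device()
    x = rand(Float32, 4, 4)

    xt = transpose(x)                  # Transpose is one case of Adapt.WrappedArray
    @assert xt isa Adapt.WrappedArray
    @assert !MLDataDevices.isleaf(xt)  # with this patch, wrappers are not Functors leaves

    y = cdev(xt)                       # takes the `Adapt.adapt(D, x)` branch
    @assert y isa Transpose            # the wrapper type survives the device move

The same route is what the new test exercises: differentiating through
`cdev(x')` now flows through the single `rrule` on `Adapt.adapt` rather than
through per-wrapper device methods.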