Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

fix: minor load time patch + remove UnrolledUtilities #75

Merged
merged 2 commits into from
Sep 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,19 +1,18 @@
name = "MLDataDevices"
uuid = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "1.1.0"
version = "1.1.1"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
UnrolledUtilities = "0fe1646c-419e-43be-ac14-22321958931b"

[weakdeps]
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
Expand All @@ -29,6 +28,7 @@ oneAPI = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b"
[extensions]
MLDataDevicesAMDGPUExt = "AMDGPU"
MLDataDevicesCUDAExt = "CUDA"
MLDataDevicesChainRulesCoreExt = "ChainRulesCore"
MLDataDevicesFillArraysExt = "FillArrays"
MLDataDevicesGPUArraysExt = "GPUArrays"
MLDataDevicesMLUtilsExt = "MLUtils"
Expand Down Expand Up @@ -57,7 +57,6 @@ RecursiveArrayTools = "3.8"
ReverseDiff = "1.15"
SparseArrays = "1.10"
Tracker = "0.2.34"
UnrolledUtilities = "0.1.2"
Zygote = "0.6.69"
cuDNN = "1.3"
julia = "1.10"
Expand Down
19 changes: 19 additions & 0 deletions ext/MLDataDevicesChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
module MLDataDevicesChainRulesCoreExt

using Adapt: Adapt
using ChainRulesCore: ChainRulesCore, NoTangent, @non_differentiable

using MLDataDevices: AbstractDevice, get_device, get_device_type

# Device queries are marked as non-differentiable for any argument.
@non_differentiable get_device(::Any)
@non_differentiable get_device_type(::Any)

# Reverse rule for `Adapt.adapt_storage(to, x)` with a device `to` and an array `x`.
# The primal is the plain `adapt_storage` result. The pullback returns no tangent for
# the function and for `to`, and passes the incoming cotangent through the device
# object reported by `get_device(x)` (devices are callable).
function ChainRulesCore.rrule(
        ::typeof(Adapt.adapt_storage), to::AbstractDevice, x::AbstractArray)
    y = Adapt.adapt_storage(to, x)
    # `x` is only read, never reassigned, so capturing it here does not box it.
    function ∇adapt_storage(Δ)
        return (NoTangent(), NoTangent(), (get_device(x))(Δ))
    end
    return y, ∇adapt_storage
end

end
3 changes: 0 additions & 3 deletions src/MLDataDevices.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
module MLDataDevices

using Adapt: Adapt
using ChainRulesCore: ChainRulesCore, NoTangent
using Functors: Functors, fleaves
using Preferences: @delete_preferences!, @load_preference, @set_preferences!
using Random: AbstractRNG, Random

const CRC = ChainRulesCore

abstract type AbstractDevice <: Function end
abstract type AbstractGPUDevice <: AbstractDevice end

Expand Down
31 changes: 30 additions & 1 deletion src/internal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ module Internal
using Functors: fmap
using Preferences: load_preference
using Random: AbstractRNG
using UnrolledUtilities: unrolled_mapreduce

using ..MLDataDevices: MLDataDevices, AbstractDevice, CPUDevice, CUDADevice, AMDGPUDevice,
MetalDevice, oneAPIDevice, supported_gpu_backends, GPU_DEVICES,
Expand Down Expand Up @@ -150,6 +149,34 @@ for op in (:get_device, :get_device_type)
end
end

# Entry point: lift the length of `itr` into a `Val` via `static_length` and forward
# to the length-specialised methods.
function unrolled_mapreduce(f::F, op::O, itr) where {F, O}
    len = static_length(itr)
    return unrolled_mapreduce(f, op, itr, len)
end

# Length-0 case: there is nothing to map or reduce, so this always throws.
function unrolled_mapreduce(::F, ::O, _, ::Val{0}) where {F, O}
    throw(ErrorException("Cannot unroll over an empty iterator."))
end

# Length-1 case: no reduction is needed; apply `f` to the single element and ignore `op`.
function unrolled_mapreduce(f::F, ::O, itr, ::Val{1}) where {F, O}
    x = only(itr)
    return f(x)
end

# General case: emit straight-line code that maps `f` over `itr[1]`, ..., `itr[N]` and
# then folds the mapped values from the left with `op`, i.e.
# `op(op(op(f(itr[1]), f(itr[2])), f(itr[3])), ...)`.
@generated function unrolled_mapreduce(f::F, op::O, itr, ::Val{N}) where {F, O, N}
    # One fresh local and one assignment `local_k = f(itr[k])` per element.
    mapped = Symbol[]
    map_stmts = Expr[]
    for k in 1:N
        name = gensym("f_itr_$(k)")
        push!(mapped, name)
        push!(map_stmts, :($name = f(itr[$k])))
    end

    # One fresh local per reduction step. The first step combines the first two mapped
    # values; each later step folds the next mapped value into the previous result.
    partial = Symbol[]
    fold_stmts = Expr[]
    for k in 1:(N - 1)
        name = gensym("op_$(k)")
        lhs = k == 1 ? mapped[1] : partial[k - 1]
        push!(fold_stmts, :($name = op($lhs, $(mapped[k + 1]))))
        push!(partial, name)
    end
    result = partial[end]

    return quote
        $(Expr(:meta, :inline))
        # The inbounds region wraps only the mapping statements (the `itr[k]` reads).
        $(Expr(:inbounds, true))
        $(Expr(:block, map_stmts...))
        $(Expr(:inbounds, :pop))
        $(Expr(:block, fold_stmts...))
        return $result
    end
end

function unsafe_free_internal!(x::AbstractArray)
unsafe_free_internal!(MLDataDevices.get_device_type(x), x)
return
Expand All @@ -162,4 +189,6 @@ function unsafe_free!(x)
return
end

# Lift the length of a tuple into the type domain as `Val{length}`.
function static_length(t::Tuple)
    return Val{length(t)}()
end

end
18 changes: 3 additions & 15 deletions src/public.jl
Original file line number Diff line number Diff line change
Expand Up @@ -308,13 +308,9 @@ for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI)
end

for op in (:get_device, :get_device_type)
@eval begin
function $(op)(x)
hasmethod(Internal.$(op), Tuple{typeof(x)}) && return Internal.$(op)(x)
return mapreduce(Internal.$(op), Internal.combine_devices, fleaves(x))
end

CRC.@non_differentiable $op(::Any)
@eval function $(op)(x)
hasmethod(Internal.$(op), Tuple{typeof(x)}) && return Internal.$(op)(x)
return mapreduce(Internal.$(op), Internal.combine_devices, fleaves(x))
end
end

Expand All @@ -337,11 +333,3 @@ for T in (AMDGPUDevice, AMDGPUDevice{Nothing}, CUDADevice,
CUDADevice{Nothing}, MetalDevice, oneAPIDevice)
@eval Adapt.adapt_storage(to::$(T), x::AbstractRange) = Adapt.adapt(to, collect(x))
end

# Chain Rules Core
# Reverse rule for `Adapt.adapt_storage(to, x)`: the primal is the adapted array, and the
# pullback passes the incoming cotangent through the device object reported by
# `get_device(x)`, with no tangent for the function or for `to`.
function CRC.rrule(::typeof(Adapt.adapt_storage), to::AbstractDevice, x::AbstractArray)
    y = Adapt.adapt_storage(to, x)
    # `x` is only read, never reassigned, so capturing it here does not box it.
    function ∇adapt_storage(Δ)
        return (NoTangent(), NoTangent(), (get_device(x))(Δ))
    end
    return y, ∇adapt_storage
end
Loading