From 2921418396d7e5d40ba2466e5a22147a4f57026d Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 15 Nov 2024 17:57:49 -0500 Subject: [PATCH 1/2] fix: circular dep due to Functors --- Project.toml | 5 ++--- ext/ComponentArraysConstructionBaseExt.jl | 7 ------- src/ComponentArrays.jl | 2 +- src/componentarray.jl | 2 ++ 4 files changed, 5 insertions(+), 11 deletions(-) delete mode 100644 ext/ComponentArraysConstructionBaseExt.jl diff --git a/Project.toml b/Project.toml index 01a7c1fc..3419e6e3 100644 --- a/Project.toml +++ b/Project.toml @@ -1,11 +1,12 @@ name = "ComponentArrays" uuid = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" authors = ["Jonnie Diegelman <47193959+jonniedie@users.noreply.github.com>"] -version = "0.15.18" +version = "0.15.19" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" @@ -14,7 +15,6 @@ StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" [weakdeps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" -ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" @@ -25,7 +25,6 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] ComponentArraysAdaptExt = "Adapt" -ComponentArraysConstructionBaseExt = "ConstructionBase" ComponentArraysGPUArraysExt = "GPUArrays" ComponentArraysOptimisersExt = "Optimisers" ComponentArraysRecursiveArrayToolsExt = "RecursiveArrayTools" diff --git a/ext/ComponentArraysConstructionBaseExt.jl b/ext/ComponentArraysConstructionBaseExt.jl deleted file mode 100644 index db6eeb29..00000000 --- a/ext/ComponentArraysConstructionBaseExt.jl +++ /dev/null @@ -1,7 +0,0 @@ -module ComponentArraysConstructionBaseExt - -using ComponentArrays, ConstructionBase - -ConstructionBase.setproperties(x::ComponentVector, patch::NamedTuple) = ComponentVector(x; patch...) - -end diff --git a/src/ComponentArrays.jl b/src/ComponentArrays.jl index 0bf6e6ec..3dc8b602 100644 --- a/src/ComponentArrays.jl +++ b/src/ComponentArrays.jl @@ -2,6 +2,7 @@ module ComponentArrays import ChainRulesCore import StaticArrayInterface, ArrayInterface, Functors +import ConstructionBase using LinearAlgebra using StaticArraysCore: StaticArray, SArray, SVector, SMatrix @@ -9,7 +10,6 @@ using StaticArraysCore: StaticArray, SArray, SVector, SMatrix const FlatIdx = Union{Integer, CartesianIndex, CartesianIndices, AbstractArray{<:Integer}} const FlatOrColonIdx = Union{FlatIdx, Colon} - include("utils.jl") export fastindices # Deprecated diff --git a/src/componentarray.jl b/src/componentarray.jl index 00158204..3991edb9 100644 --- a/src/componentarray.jl +++ b/src/componentarray.jl @@ -89,6 +89,8 @@ ComponentVector{T}(::UndefInitializer, ax) where {T} = ComponentArray{T}(undef, ComponentVector(data::AbstractVector, ax) = ComponentArray(data, ax) ComponentVector(data::AbstractArray, ax) = throw(DimensionMismatch("A `ComponentVector` must be initialized with a 1-dimensional array. This array is $(ndims(data))-dimensional.")) +ConstructionBase.setproperties(x::ComponentVector, patch::NamedTuple) = ComponentVector(x; patch...) + # Add new fields to component Vector function ComponentArray(x::ComponentVector; kwargs...) return foldl((x1, kwarg) -> _maybe_add_field(x1, kwarg), (kwargs...,); init=x) From 5f627de6ac60b05f2d1a48de8a57faff66996b37 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 15 Nov 2024 18:00:04 -0500 Subject: [PATCH 2/2] fix: circular dep due to Adapt --- Project.toml | 3 +-- ext/ComponentArraysAdaptExt.jl | 13 ------------- src/ComponentArrays.jl | 1 + src/componentarray.jl | 8 ++++++++ 4 files changed, 10 insertions(+), 15 deletions(-) delete mode 100644 ext/ComponentArraysAdaptExt.jl diff --git a/Project.toml b/Project.toml index 3419e6e3..7fdcdead 100644 --- a/Project.toml +++ b/Project.toml @@ -4,6 +4,7 @@ authors = ["Jonnie Diegelman <47193959+jonniedie@users.noreply.github.com>"] version = "0.15.19" [deps] +Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" @@ -14,7 +15,6 @@ StaticArrayInterface = "0d7ed370-da01-4f52-bd93-41d350b8b718" StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" [weakdeps] -Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" @@ -24,7 +24,6 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] -ComponentArraysAdaptExt = "Adapt" ComponentArraysGPUArraysExt = "GPUArrays" ComponentArraysOptimisersExt = "Optimisers" ComponentArraysRecursiveArrayToolsExt = "RecursiveArrayTools" diff --git a/ext/ComponentArraysAdaptExt.jl b/ext/ComponentArraysAdaptExt.jl deleted file mode 100644 index 8e04c0a9..00000000 --- a/ext/ComponentArraysAdaptExt.jl +++ /dev/null @@ -1,13 +0,0 @@ -module ComponentArraysAdaptExt - -using ComponentArrays, Adapt - -function Adapt.adapt_structure(to, x::ComponentArray) - data = adapt(to, getdata(x)) - return ComponentArray(data, getaxes(x)) -end - -Adapt.adapt_storage(::Type{ComponentArray{T,N,A,Ax}}, xs::AT) where {T,N,A,Ax,AT<:AbstractArray} = - Adapt.adapt_storage(A, xs) - -end diff --git a/src/ComponentArrays.jl b/src/ComponentArrays.jl index 3dc8b602..2f7c1bf8 100644 --- a/src/ComponentArrays.jl +++ b/src/ComponentArrays.jl @@ -3,6 +3,7 @@ module ComponentArrays import ChainRulesCore import StaticArrayInterface, ArrayInterface, Functors import ConstructionBase +import Adapt using LinearAlgebra using StaticArraysCore: StaticArray, SArray, SVector, SMatrix diff --git a/src/componentarray.jl b/src/componentarray.jl index 3991edb9..ecb0a632 100644 --- a/src/componentarray.jl +++ b/src/componentarray.jl @@ -58,6 +58,14 @@ function ComponentArray(data, ax::AbstractAxis...) return LazyArray(ComponentArray(x, axs...) for x in part_data) end +function Adapt.adapt_structure(to, x::ComponentArray) + data = Adapt.adapt(to, getdata(x)) + return ComponentArray(data, getaxes(x)) +end + +Adapt.adapt_storage(::Type{ComponentArray{T,N,A,Ax}}, xs::AT) where {T,N,A,Ax,AT<:AbstractArray} = + Adapt.adapt_storage(A, xs) + # Entry from NamedTuple, Dict, or kwargs ComponentArray{T}(nt::NamedTuple) where T = ComponentArray(make_carray_args(T, nt)...) ComponentArray{T}(::NamedTuple{(), Tuple{}}) where T = ComponentArray(T[], (FlatAxis(),))