From 501544537b7aaba8d0970dd4bc7100dba3e947ce Mon Sep 17 00:00:00 2001 From: Vedant Puri Date: Wed, 27 Dec 2023 21:21:25 +0530 Subject: [PATCH] Resolve ambiguities in https://github.com/jonniedie/ComponentArrays.jl/pull/231 (#230) --- .github/workflows/ci.yml | 1 + src/array_interface.jl | 18 +++++++++--------- src/componentarray.jl | 12 ++++++------ 3 files changed, 16 insertions(+), 15 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index d0ce8ddc..ed8067a9 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -19,6 +19,7 @@ jobs: - '1.8' - '1.9' - '1' # Leave this line unchanged. '1' will automatically expand to the latest stable 1.x release of Julia. + - '1.10.0-beta3' os: - ubuntu-latest arch: diff --git a/src/array_interface.jl b/src/array_interface.jl index 7e6c6249..c9060eb4 100644 --- a/src/array_interface.jl +++ b/src/array_interface.jl @@ -33,9 +33,9 @@ second_axis(::ComponentVector) = FlatAxis() # Are all these methods necessary? # TODO: See what we can reduce down to without getting ambiguity errors -Base.vcat(x::ComponentVector, y::AbstractVector) = vcat(getdata(x), y) -Base.vcat(x::AbstractVector, y::ComponentVector) = vcat(x, getdata(y)) -function Base.vcat(x::ComponentVector, y::ComponentVector) +Base.vcat(x::ComponentVector{<:Number}, y::AbstractVector{<:Number}) = vcat(getdata(x), y) +Base.vcat(x::AbstractVector{<:Number}, y::ComponentVector{<:Number}) = vcat(x, getdata(y)) +function Base.vcat(x::ComponentVector{<:Number}, y::ComponentVector{<:Number}) if reduce((accum, key) -> accum || (key in keys(x)), keys(y); init=false) return vcat(getdata(x), getdata(y)) else @@ -46,7 +46,7 @@ function Base.vcat(x::ComponentVector, y::ComponentVector) return ComponentArray(vcat(data_x, data_y), Axis((;idxmap_x..., idxmap_y...))) end end -function Base.vcat(x::AbstractComponentVecOrMat, y::AbstractComponentVecOrMat) +function Base.vcat(x::AbstractComponentVecOrMat{<:Number}, y::AbstractComponentVecOrMat{<:Number}) ax_x, ax_y = getindex.(getaxes.((x, y)), 1) if reduce((accum, key) -> accum || (key in keys(ax_x)), keys(ax_y); init=false) || getaxes(x)[2:end] != getaxes(y)[2:end] return vcat(getdata(x), getdata(y)) @@ -57,10 +57,10 @@ function Base.vcat(x::AbstractComponentVecOrMat, y::AbstractComponentVecOrMat) return ComponentArray(vcat(data_x, data_y), Axis((;idxmap_x..., idxmap_y...)), getaxes(x)[2:end]...) end end -Base.vcat(x::CV...) where {CV<:AdjOrTransComponentArray} = ComponentArray(reduce(vcat, map(y->getdata(y.parent)', x)), getaxes(x[1])) -Base.vcat(x::ComponentVector, args...) = vcat(getdata(x), getdata.(args)...) -Base.vcat(x::ComponentVector, args::Vararg{Union{Number, UniformScaling, AbstractVecOrMat}}) = vcat(getdata(x), getdata.(args)...) -Base.vcat(x::ComponentVector, args::Vararg{AbstractVector{T}, N}) where {T,N} = vcat(getdata(x), getdata.(args)...) +Base.vcat(x::CV...) where {CV<:AdjOrTransComponentArray{<:Number}} = ComponentArray(reduce(vcat, map(y->getdata(y.parent)', x)), getaxes(x[1])) +Base.vcat(x::ComponentVector{<:Number}, args...) = vcat(getdata(x), getdata.(args)...) +Base.vcat(x::ComponentVector{<:Number}, args::Vararg{Union{Number, UniformScaling, AbstractVecOrMat{<:Number}}}) = vcat(getdata(x), getdata.(args)...) +Base.vcat(x::ComponentVector{<:Number}, args::Vararg{AbstractVector{T}, N}) where {T<:Number,N} = vcat(getdata(x), getdata.(args)...) function Base.hvcat(row_lengths::NTuple{N,Int}, xs::Vararg{AbstractComponentVecOrMat}) where {N} i = 1 @@ -147,4 +147,4 @@ end Base.stride(x::ComponentArray, k) = stride(getdata(x), k) Base.stride(x::ComponentArray, k::Int64) = stride(getdata(x), k) -ArrayInterface.parent_type(::Type{ComponentArray{T,N,A,Axes}}) where {T,N,A,Axes} = A \ No newline at end of file +ArrayInterface.parent_type(::Type{ComponentArray{T,N,A,Axes}}) where {T,N,A,Axes} = A diff --git a/src/componentarray.jl b/src/componentarray.jl index 00ed6064..736b0c44 100644 --- a/src/componentarray.jl +++ b/src/componentarray.jl @@ -124,12 +124,12 @@ const AdjOrTransComponentArray{T, A} = Union{Adjoint{T, A}, Transpose{T, A}} whe const AdjOrTransComponentVector{T} = Union{Adjoint{T, A}, Transpose{T, A}} where A<:ComponentVector const AdjOrTransComponentMatrix{T} = Union{Adjoint{T, A}, Transpose{T, A}} where A<:ComponentMatrix -const ComponentVecOrMat = Union{ComponentVector, ComponentMatrix} -const AdjOrTransComponentVecOrMat = AdjOrTrans{T, <:ComponentVecOrMat} where T -const AbstractComponentArray = Union{ComponentArray, AdjOrTransComponentArray} -const AbstractComponentVecOrMat = Union{ComponentVecOrMat, AdjOrTransComponentVecOrMat} -const AbstractComponentVector = Union{ComponentVector, AdjOrTransComponentVector} -const AbstractComponentMatrix = Union{ComponentMatrix, AdjOrTransComponentMatrix} +const ComponentVecOrMat{T} = Union{ComponentVector{T}, ComponentMatrix{T}} where{T} +const AdjOrTransComponentVecOrMat{T} = AdjOrTrans{T, <:ComponentVecOrMat} where {T} +const AbstractComponentArray{T} = Union{ComponentArray{T}, AdjOrTransComponentArray{T}} where{T} +const AbstractComponentVecOrMat{T} = Union{ComponentVecOrMat{T}, AdjOrTransComponentVecOrMat{T}} where{T} +const AbstractComponentVector{T} = Union{ComponentVector{T}, AdjOrTransComponentVector{T}} where{T} +const AbstractComponentMatrix{T} = Union{ComponentMatrix{T}, AdjOrTransComponentMatrix{T}} where{T} ## Constructor helpers