From c96d3e777601a016c55f56de7e1db6e270dde9af Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Wed, 7 Aug 2024 16:42:30 +0000 Subject: [PATCH] Accept axes in Base.checkdims_perm (#55403) Since `checkdims_perm` only checks the axes of the arrays that are passed to it, this PR adds a method that accepts the axes as arguments instead of the arrays. This will avoid having to specialize on array types. An example of an improvement: On master ```julia julia> using LinearAlgebra julia> D = Diagonal(zeros(1)); julia> Dv = Diagonal(view(zeros(1),:)); julia> @time @eval permutedims(D, (2,1)); 0.016841 seconds (13.68 k allocations: 680.672 KiB, 51.37% compilation time) julia> @time @eval permutedims(Dv, (2,1)); 0.009303 seconds (11.24 k allocations: 564.203 KiB, 97.79% compilation time) ``` This PR ```julia julia> @time @eval permutedims(D, (2,1)); 0.016837 seconds (13.42 k allocations: 667.438 KiB, 51.05% compilation time) julia> @time @eval permutedims(Dv, (2,1)); 0.009076 seconds (6.59 k allocations: 321.156 KiB, 97.46% compilation time) ``` The allocations are lower in the second call. I've retained the original method as well, as some packages seem to be using it. This now forwards the axes to the new method. --- base/multidimensional.jl | 9 ++++----- base/permuteddimsarray.jl | 2 +- stdlib/LinearAlgebra/src/bidiag.jl | 2 +- stdlib/LinearAlgebra/src/diagonal.jl | 2 +- stdlib/LinearAlgebra/src/tridiag.jl | 4 ++-- 5 files changed, 9 insertions(+), 10 deletions(-) diff --git a/base/multidimensional.jl b/base/multidimensional.jl index bd3641db4999c..5e32a19c2cafb 100644 --- a/base/multidimensional.jl +++ b/base/multidimensional.jl @@ -1669,11 +1669,10 @@ function permutedims(B::StridedArray, perm) permutedims!(P, B, perm) end -function checkdims_perm(P::AbstractArray{TP,N}, B::AbstractArray{TB,N}, perm) where {TP,TB,N} - indsB = axes(B) - length(perm) == N || throw(ArgumentError("expected permutation of size $N, but length(perm)=$(length(perm))")) +checkdims_perm(P::AbstractArray{TP,N}, B::AbstractArray{TB,N}, perm) where {TP,TB,N} = checkdims_perm(axes(P), axes(B), perm) +function checkdims_perm(indsP::NTuple{N, AbstractUnitRange}, indsB::NTuple{N, AbstractUnitRange}, perm) where {N} + length(perm) == N || throw(ArgumentError(LazyString("expected permutation of size ", N, ", but length(perm)=", length(perm)))) isperm(perm) || throw(ArgumentError("input is not a permutation")) - indsP = axes(P) for i in eachindex(perm) indsP[i] == indsB[perm[i]] || throw(DimensionMismatch("destination tensor of incorrect size")) end @@ -1683,7 +1682,7 @@ end for (V, PT, BT) in Any[((:N,), BitArray, BitArray), ((:T,:N), Array, StridedArray)] @eval @generated function permutedims!(P::$PT{$(V...)}, B::$BT{$(V...)}, perm) where $(V...) quote - checkdims_perm(P, B, perm) + checkdims_perm(axes(P), axes(B), perm) #calculates all the strides native_strides = size_to_strides(1, size(B)...) diff --git a/base/permuteddimsarray.jl b/base/permuteddimsarray.jl index 4e77d6b13ce21..cf9748168aac2 100644 --- a/base/permuteddimsarray.jl +++ b/base/permuteddimsarray.jl @@ -282,7 +282,7 @@ regions. See also [`permutedims`](@ref). """ function permutedims!(dest, src::AbstractArray, perm) - Base.checkdims_perm(dest, src, perm) + Base.checkdims_perm(axes(dest), axes(src), perm) P = PermutedDimsArray(dest, invperm(perm)) _copy!(P, src) return dest diff --git a/stdlib/LinearAlgebra/src/bidiag.jl b/stdlib/LinearAlgebra/src/bidiag.jl index 24958422015ab..adb5f8c51bf47 100644 --- a/stdlib/LinearAlgebra/src/bidiag.jl +++ b/stdlib/LinearAlgebra/src/bidiag.jl @@ -287,7 +287,7 @@ adjoint(B::Bidiagonal{<:Number, <:Base.ReshapedArray{<:Number,1,<:Adjoint}}) = transpose(B::Bidiagonal{<:Number}) = Bidiagonal(B.dv, B.ev, B.uplo == 'U' ? :L : :U) permutedims(B::Bidiagonal) = Bidiagonal(B.dv, B.ev, B.uplo == 'U' ? 'L' : 'U') function permutedims(B::Bidiagonal, perm) - Base.checkdims_perm(B, B, perm) + Base.checkdims_perm(axes(B), axes(B), perm) NTuple{2}(perm) == (2, 1) ? permutedims(B) : B end function Base.copy(aB::Adjoint{<:Any,<:Bidiagonal}) diff --git a/stdlib/LinearAlgebra/src/diagonal.jl b/stdlib/LinearAlgebra/src/diagonal.jl index 89202e66597f8..77459f7cca520 100644 --- a/stdlib/LinearAlgebra/src/diagonal.jl +++ b/stdlib/LinearAlgebra/src/diagonal.jl @@ -745,7 +745,7 @@ adjoint(D::Diagonal{<:Number}) = Diagonal(vec(adjoint(D.diag))) adjoint(D::Diagonal{<:Number,<:Base.ReshapedArray{<:Number,1,<:Adjoint}}) = Diagonal(adjoint(parent(D.diag))) adjoint(D::Diagonal) = Diagonal(adjoint.(D.diag)) permutedims(D::Diagonal) = D -permutedims(D::Diagonal, perm) = (Base.checkdims_perm(D, D, perm); D) +permutedims(D::Diagonal, perm) = (Base.checkdims_perm(axes(D), axes(D), perm); D) function diag(D::Diagonal{T}, k::Integer=0) where T # every branch call similar(..., ::Int) to make sure the diff --git a/stdlib/LinearAlgebra/src/tridiag.jl b/stdlib/LinearAlgebra/src/tridiag.jl index e217425402df9..c14ed5690198c 100644 --- a/stdlib/LinearAlgebra/src/tridiag.jl +++ b/stdlib/LinearAlgebra/src/tridiag.jl @@ -173,7 +173,7 @@ adjoint(S::SymTridiagonal{<:Number, <:Base.ReshapedArray{<:Number,1,<:Adjoint}}) permutedims(S::SymTridiagonal) = S function permutedims(S::SymTridiagonal, perm) - Base.checkdims_perm(S, S, perm) + Base.checkdims_perm(axes(S), axes(S), perm) NTuple{2}(perm) == (2, 1) ? permutedims(S) : S end Base.copy(S::Adjoint{<:Any,<:SymTridiagonal}) = SymTridiagonal(map(x -> copy.(adjoint.(x)), (S.parent.dv, S.parent.ev))...) @@ -639,7 +639,7 @@ adjoint(S::Tridiagonal{<:Number, <:Base.ReshapedArray{<:Number,1,<:Adjoint}}) = transpose(S::Tridiagonal{<:Number}) = Tridiagonal(S.du, S.d, S.dl) permutedims(T::Tridiagonal) = Tridiagonal(T.du, T.d, T.dl) function permutedims(T::Tridiagonal, perm) - Base.checkdims_perm(T, T, perm) + Base.checkdims_perm(axes(T), axes(T), perm) NTuple{2}(perm) == (2, 1) ? permutedims(T) : T end Base.copy(aS::Adjoint{<:Any,<:Tridiagonal}) = (S = aS.parent; Tridiagonal(map(x -> copy.(adjoint.(x)), (S.du, S.d, S.dl))...))