Skip to content

Commit c96d3e7

Browse files
jishnublazarusA
authored andcommitted
Accept axes in Base.checkdims_perm (JuliaLang#55403)
Since `checkdims_perm` only checks the axes of the arrays that are passed to it, this PR adds a method that accepts the axes as arguments instead of the arrays. This will avoid having to specialize on array types. An example of an improvement: On master ```julia julia> using LinearAlgebra julia> D = Diagonal(zeros(1)); julia> Dv = Diagonal(view(zeros(1),:)); julia> @time @eval permutedims(D, (2,1)); 0.016841 seconds (13.68 k allocations: 680.672 KiB, 51.37% compilation time) julia> @time @eval permutedims(Dv, (2,1)); 0.009303 seconds (11.24 k allocations: 564.203 KiB, 97.79% compilation time) ``` This PR ```julia julia> @time @eval permutedims(D, (2,1)); 0.016837 seconds (13.42 k allocations: 667.438 KiB, 51.05% compilation time) julia> @time @eval permutedims(Dv, (2,1)); 0.009076 seconds (6.59 k allocations: 321.156 KiB, 97.46% compilation time) ``` The allocations are lower in the second call. I've retained the original method as well, as some packages seem to be using it. This now forwards the axes to the new method.
1 parent 1fbf822 commit c96d3e7

File tree

5 files changed

+9
-10
lines changed

5 files changed

+9
-10
lines changed

base/multidimensional.jl

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1669,11 +1669,10 @@ function permutedims(B::StridedArray, perm)
16691669
permutedims!(P, B, perm)
16701670
end
16711671

1672-
function checkdims_perm(P::AbstractArray{TP,N}, B::AbstractArray{TB,N}, perm) where {TP,TB,N}
1673-
indsB = axes(B)
1674-
length(perm) == N || throw(ArgumentError("expected permutation of size $N, but length(perm)=$(length(perm))"))
1672+
checkdims_perm(P::AbstractArray{TP,N}, B::AbstractArray{TB,N}, perm) where {TP,TB,N} = checkdims_perm(axes(P), axes(B), perm)
1673+
function checkdims_perm(indsP::NTuple{N, AbstractUnitRange}, indsB::NTuple{N, AbstractUnitRange}, perm) where {N}
1674+
length(perm) == N || throw(ArgumentError(LazyString("expected permutation of size ", N, ", but length(perm)=", length(perm))))
16751675
isperm(perm) || throw(ArgumentError("input is not a permutation"))
1676-
indsP = axes(P)
16771676
for i in eachindex(perm)
16781677
indsP[i] == indsB[perm[i]] || throw(DimensionMismatch("destination tensor of incorrect size"))
16791678
end
@@ -1683,7 +1682,7 @@ end
16831682
for (V, PT, BT) in Any[((:N,), BitArray, BitArray), ((:T,:N), Array, StridedArray)]
16841683
@eval @generated function permutedims!(P::$PT{$(V...)}, B::$BT{$(V...)}, perm) where $(V...)
16851684
quote
1686-
checkdims_perm(P, B, perm)
1685+
checkdims_perm(axes(P), axes(B), perm)
16871686

16881687
#calculates all the strides
16891688
native_strides = size_to_strides(1, size(B)...)

base/permuteddimsarray.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -282,7 +282,7 @@ regions.
282282
See also [`permutedims`](@ref).
283283
"""
284284
function permutedims!(dest, src::AbstractArray, perm)
285-
Base.checkdims_perm(dest, src, perm)
285+
Base.checkdims_perm(axes(dest), axes(src), perm)
286286
P = PermutedDimsArray(dest, invperm(perm))
287287
_copy!(P, src)
288288
return dest

stdlib/LinearAlgebra/src/bidiag.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -287,7 +287,7 @@ adjoint(B::Bidiagonal{<:Number, <:Base.ReshapedArray{<:Number,1,<:Adjoint}}) =
287287
transpose(B::Bidiagonal{<:Number}) = Bidiagonal(B.dv, B.ev, B.uplo == 'U' ? :L : :U)
288288
permutedims(B::Bidiagonal) = Bidiagonal(B.dv, B.ev, B.uplo == 'U' ? 'L' : 'U')
289289
function permutedims(B::Bidiagonal, perm)
290-
Base.checkdims_perm(B, B, perm)
290+
Base.checkdims_perm(axes(B), axes(B), perm)
291291
NTuple{2}(perm) == (2, 1) ? permutedims(B) : B
292292
end
293293
function Base.copy(aB::Adjoint{<:Any,<:Bidiagonal})

stdlib/LinearAlgebra/src/diagonal.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -745,7 +745,7 @@ adjoint(D::Diagonal{<:Number}) = Diagonal(vec(adjoint(D.diag)))
745745
adjoint(D::Diagonal{<:Number,<:Base.ReshapedArray{<:Number,1,<:Adjoint}}) = Diagonal(adjoint(parent(D.diag)))
746746
adjoint(D::Diagonal) = Diagonal(adjoint.(D.diag))
747747
permutedims(D::Diagonal) = D
748-
permutedims(D::Diagonal, perm) = (Base.checkdims_perm(D, D, perm); D)
748+
permutedims(D::Diagonal, perm) = (Base.checkdims_perm(axes(D), axes(D), perm); D)
749749

750750
function diag(D::Diagonal{T}, k::Integer=0) where T
751751
# every branch call similar(..., ::Int) to make sure the

stdlib/LinearAlgebra/src/tridiag.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ adjoint(S::SymTridiagonal{<:Number, <:Base.ReshapedArray{<:Number,1,<:Adjoint}})
173173

174174
permutedims(S::SymTridiagonal) = S
175175
function permutedims(S::SymTridiagonal, perm)
176-
Base.checkdims_perm(S, S, perm)
176+
Base.checkdims_perm(axes(S), axes(S), perm)
177177
NTuple{2}(perm) == (2, 1) ? permutedims(S) : S
178178
end
179179
Base.copy(S::Adjoint{<:Any,<:SymTridiagonal}) = SymTridiagonal(map(x -> copy.(adjoint.(x)), (S.parent.dv, S.parent.ev))...)
@@ -639,7 +639,7 @@ adjoint(S::Tridiagonal{<:Number, <:Base.ReshapedArray{<:Number,1,<:Adjoint}}) =
639639
transpose(S::Tridiagonal{<:Number}) = Tridiagonal(S.du, S.d, S.dl)
640640
permutedims(T::Tridiagonal) = Tridiagonal(T.du, T.d, T.dl)
641641
function permutedims(T::Tridiagonal, perm)
642-
Base.checkdims_perm(T, T, perm)
642+
Base.checkdims_perm(axes(T), axes(T), perm)
643643
NTuple{2}(perm) == (2, 1) ? permutedims(T) : T
644644
end
645645
Base.copy(aS::Adjoint{<:Any,<:Tridiagonal}) = (S = aS.parent; Tridiagonal(map(x -> copy.(adjoint.(x)), (S.du, S.d, S.dl))...))

0 commit comments

Comments
 (0)