Skip to content

Commit

Permalink
Accept axes in Base.checkdims_perm (#55403)
Browse files Browse the repository at this point in the history
Since `checkdims_perm` only checks the axes of the arrays that are
passed to it, this PR adds a method that accepts the axes as arguments
instead of the arrays. This will avoid having to specialize on array
types.
An example of an improvement:
On master
```julia
julia> using LinearAlgebra

julia> D = Diagonal(zeros(1));

julia> Dv = Diagonal(view(zeros(1),:));

julia> @time @eval permutedims(D, (2,1));
  0.016841 seconds (13.68 k allocations: 680.672 KiB, 51.37% compilation time)

julia> @time @eval permutedims(Dv, (2,1));
  0.009303 seconds (11.24 k allocations: 564.203 KiB, 97.79% compilation time)
```
This PR
```julia
julia> @time @eval permutedims(D, (2,1));
  0.016837 seconds (13.42 k allocations: 667.438 KiB, 51.05% compilation time)

julia> @time @eval permutedims(Dv, (2,1));
  0.009076 seconds (6.59 k allocations: 321.156 KiB, 97.46% compilation time)
```
The allocations are lower in the second call.

I've retained the original method as well, as some packages seem to be
using it. This now forwards the axes to the new method.
  • Loading branch information
jishnub authored Aug 7, 2024
1 parent b43e247 commit bd582f7
Show file tree
Hide file tree
Showing 5 changed files with 9 additions and 10 deletions.
9 changes: 4 additions & 5 deletions base/multidimensional.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1669,11 +1669,10 @@ function permutedims(B::StridedArray, perm)
permutedims!(P, B, perm)
end

function checkdims_perm(P::AbstractArray{TP,N}, B::AbstractArray{TB,N}, perm) where {TP,TB,N}
indsB = axes(B)
length(perm) == N || throw(ArgumentError("expected permutation of size $N, but length(perm)=$(length(perm))"))
checkdims_perm(P::AbstractArray{TP,N}, B::AbstractArray{TB,N}, perm) where {TP,TB,N} = checkdims_perm(axes(P), axes(B), perm)
function checkdims_perm(indsP::NTuple{N, AbstractUnitRange}, indsB::NTuple{N, AbstractUnitRange}, perm) where {N}
length(perm) == N || throw(ArgumentError(LazyString("expected permutation of size ", N, ", but length(perm)=", length(perm))))
isperm(perm) || throw(ArgumentError("input is not a permutation"))
indsP = axes(P)
for i in eachindex(perm)
indsP[i] == indsB[perm[i]] || throw(DimensionMismatch("destination tensor of incorrect size"))
end
Expand All @@ -1683,7 +1682,7 @@ end
for (V, PT, BT) in Any[((:N,), BitArray, BitArray), ((:T,:N), Array, StridedArray)]
@eval @generated function permutedims!(P::$PT{$(V...)}, B::$BT{$(V...)}, perm) where $(V...)
quote
checkdims_perm(P, B, perm)
checkdims_perm(axes(P), axes(B), perm)

#calculates all the strides
native_strides = size_to_strides(1, size(B)...)
Expand Down
2 changes: 1 addition & 1 deletion base/permuteddimsarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ regions.
See also [`permutedims`](@ref).
"""
function permutedims!(dest, src::AbstractArray, perm)
Base.checkdims_perm(dest, src, perm)
Base.checkdims_perm(axes(dest), axes(src), perm)
P = PermutedDimsArray(dest, invperm(perm))
_copy!(P, src)
return dest
Expand Down
2 changes: 1 addition & 1 deletion stdlib/LinearAlgebra/src/bidiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ adjoint(B::Bidiagonal{<:Number, <:Base.ReshapedArray{<:Number,1,<:Adjoint}}) =
transpose(B::Bidiagonal{<:Number}) = Bidiagonal(B.dv, B.ev, B.uplo == 'U' ? :L : :U)
permutedims(B::Bidiagonal) = Bidiagonal(B.dv, B.ev, B.uplo == 'U' ? 'L' : 'U')
function permutedims(B::Bidiagonal, perm)
Base.checkdims_perm(B, B, perm)
Base.checkdims_perm(axes(B), axes(B), perm)
NTuple{2}(perm) == (2, 1) ? permutedims(B) : B
end
function Base.copy(aB::Adjoint{<:Any,<:Bidiagonal})
Expand Down
2 changes: 1 addition & 1 deletion stdlib/LinearAlgebra/src/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -745,7 +745,7 @@ adjoint(D::Diagonal{<:Number}) = Diagonal(vec(adjoint(D.diag)))
adjoint(D::Diagonal{<:Number,<:Base.ReshapedArray{<:Number,1,<:Adjoint}}) = Diagonal(adjoint(parent(D.diag)))
adjoint(D::Diagonal) = Diagonal(adjoint.(D.diag))
permutedims(D::Diagonal) = D
permutedims(D::Diagonal, perm) = (Base.checkdims_perm(D, D, perm); D)
permutedims(D::Diagonal, perm) = (Base.checkdims_perm(axes(D), axes(D), perm); D)

function diag(D::Diagonal{T}, k::Integer=0) where T
# every branch call similar(..., ::Int) to make sure the
Expand Down
4 changes: 2 additions & 2 deletions stdlib/LinearAlgebra/src/tridiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ adjoint(S::SymTridiagonal{<:Number, <:Base.ReshapedArray{<:Number,1,<:Adjoint}})

permutedims(S::SymTridiagonal) = S
function permutedims(S::SymTridiagonal, perm)
Base.checkdims_perm(S, S, perm)
Base.checkdims_perm(axes(S), axes(S), perm)
NTuple{2}(perm) == (2, 1) ? permutedims(S) : S
end
Base.copy(S::Adjoint{<:Any,<:SymTridiagonal}) = SymTridiagonal(map(x -> copy.(adjoint.(x)), (S.parent.dv, S.parent.ev))...)
Expand Down Expand Up @@ -639,7 +639,7 @@ adjoint(S::Tridiagonal{<:Number, <:Base.ReshapedArray{<:Number,1,<:Adjoint}}) =
transpose(S::Tridiagonal{<:Number}) = Tridiagonal(S.du, S.d, S.dl)
permutedims(T::Tridiagonal) = Tridiagonal(T.du, T.d, T.dl)
function permutedims(T::Tridiagonal, perm)
Base.checkdims_perm(T, T, perm)
Base.checkdims_perm(axes(T), axes(T), perm)
NTuple{2}(perm) == (2, 1) ? permutedims(T) : T
end
Base.copy(aS::Adjoint{<:Any,<:Tridiagonal}) = (S = aS.parent; Tridiagonal(map(x -> copy.(adjoint.(x)), (S.du, S.d, S.dl))...))
Expand Down

0 comments on commit bd582f7

Please sign in to comment.