diff --git a/base/multidimensional.jl b/base/multidimensional.jl index bd3641db4999c..5e32a19c2cafb 100644 --- a/base/multidimensional.jl +++ b/base/multidimensional.jl @@ -1669,11 +1669,10 @@ function permutedims(B::StridedArray, perm) permutedims!(P, B, perm) end -function checkdims_perm(P::AbstractArray{TP,N}, B::AbstractArray{TB,N}, perm) where {TP,TB,N} - indsB = axes(B) - length(perm) == N || throw(ArgumentError("expected permutation of size $N, but length(perm)=$(length(perm))")) +checkdims_perm(P::AbstractArray{TP,N}, B::AbstractArray{TB,N}, perm) where {TP,TB,N} = checkdims_perm(axes(P), axes(B), perm) +function checkdims_perm(indsP::NTuple{N, AbstractUnitRange}, indsB::NTuple{N, AbstractUnitRange}, perm) where {N} + length(perm) == N || throw(ArgumentError(LazyString("expected permutation of size ", N, ", but length(perm)=", length(perm)))) isperm(perm) || throw(ArgumentError("input is not a permutation")) - indsP = axes(P) for i in eachindex(perm) indsP[i] == indsB[perm[i]] || throw(DimensionMismatch("destination tensor of incorrect size")) end @@ -1683,7 +1682,7 @@ end for (V, PT, BT) in Any[((:N,), BitArray, BitArray), ((:T,:N), Array, StridedArray)] @eval @generated function permutedims!(P::$PT{$(V...)}, B::$BT{$(V...)}, perm) where $(V...) quote - checkdims_perm(P, B, perm) + checkdims_perm(axes(P), axes(B), perm) #calculates all the strides native_strides = size_to_strides(1, size(B)...) diff --git a/base/permuteddimsarray.jl b/base/permuteddimsarray.jl index 4e77d6b13ce21..cf9748168aac2 100644 --- a/base/permuteddimsarray.jl +++ b/base/permuteddimsarray.jl @@ -282,7 +282,7 @@ regions. See also [`permutedims`](@ref). """ function permutedims!(dest, src::AbstractArray, perm) - Base.checkdims_perm(dest, src, perm) + Base.checkdims_perm(axes(dest), axes(src), perm) P = PermutedDimsArray(dest, invperm(perm)) _copy!(P, src) return dest diff --git a/stdlib/LinearAlgebra/src/bidiag.jl b/stdlib/LinearAlgebra/src/bidiag.jl index 24958422015ab..adb5f8c51bf47 100644 --- a/stdlib/LinearAlgebra/src/bidiag.jl +++ b/stdlib/LinearAlgebra/src/bidiag.jl @@ -287,7 +287,7 @@ adjoint(B::Bidiagonal{<:Number, <:Base.ReshapedArray{<:Number,1,<:Adjoint}}) = transpose(B::Bidiagonal{<:Number}) = Bidiagonal(B.dv, B.ev, B.uplo == 'U' ? :L : :U) permutedims(B::Bidiagonal) = Bidiagonal(B.dv, B.ev, B.uplo == 'U' ? 'L' : 'U') function permutedims(B::Bidiagonal, perm) - Base.checkdims_perm(B, B, perm) + Base.checkdims_perm(axes(B), axes(B), perm) NTuple{2}(perm) == (2, 1) ? permutedims(B) : B end function Base.copy(aB::Adjoint{<:Any,<:Bidiagonal}) diff --git a/stdlib/LinearAlgebra/src/diagonal.jl b/stdlib/LinearAlgebra/src/diagonal.jl index 89202e66597f8..77459f7cca520 100644 --- a/stdlib/LinearAlgebra/src/diagonal.jl +++ b/stdlib/LinearAlgebra/src/diagonal.jl @@ -745,7 +745,7 @@ adjoint(D::Diagonal{<:Number}) = Diagonal(vec(adjoint(D.diag))) adjoint(D::Diagonal{<:Number,<:Base.ReshapedArray{<:Number,1,<:Adjoint}}) = Diagonal(adjoint(parent(D.diag))) adjoint(D::Diagonal) = Diagonal(adjoint.(D.diag)) permutedims(D::Diagonal) = D -permutedims(D::Diagonal, perm) = (Base.checkdims_perm(D, D, perm); D) +permutedims(D::Diagonal, perm) = (Base.checkdims_perm(axes(D), axes(D), perm); D) function diag(D::Diagonal{T}, k::Integer=0) where T # every branch call similar(..., ::Int) to make sure the diff --git a/stdlib/LinearAlgebra/src/tridiag.jl b/stdlib/LinearAlgebra/src/tridiag.jl index e217425402df9..c14ed5690198c 100644 --- a/stdlib/LinearAlgebra/src/tridiag.jl +++ b/stdlib/LinearAlgebra/src/tridiag.jl @@ -173,7 +173,7 @@ adjoint(S::SymTridiagonal{<:Number, <:Base.ReshapedArray{<:Number,1,<:Adjoint}}) permutedims(S::SymTridiagonal) = S function permutedims(S::SymTridiagonal, perm) - Base.checkdims_perm(S, S, perm) + Base.checkdims_perm(axes(S), axes(S), perm) NTuple{2}(perm) == (2, 1) ? permutedims(S) : S end Base.copy(S::Adjoint{<:Any,<:SymTridiagonal}) = SymTridiagonal(map(x -> copy.(adjoint.(x)), (S.parent.dv, S.parent.ev))...) @@ -639,7 +639,7 @@ adjoint(S::Tridiagonal{<:Number, <:Base.ReshapedArray{<:Number,1,<:Adjoint}}) = transpose(S::Tridiagonal{<:Number}) = Tridiagonal(S.du, S.d, S.dl) permutedims(T::Tridiagonal) = Tridiagonal(T.du, T.d, T.dl) function permutedims(T::Tridiagonal, perm) - Base.checkdims_perm(T, T, perm) + Base.checkdims_perm(axes(T), axes(T), perm) NTuple{2}(perm) == (2, 1) ? permutedims(T) : T end Base.copy(aS::Adjoint{<:Any,<:Tridiagonal}) = (S = aS.parent; Tridiagonal(map(x -> copy.(adjoint.(x)), (S.du, S.d, S.dl))...))