@@ -111,18 +111,25 @@ _throw_reshape_colon_dimmismatch(A, dims) =
111
111
112
112
reshape (parent:: AbstractArray{T,N} , ndims:: Type{Val{N}} ) where {T,N} = parent
113
113
function reshape (parent:: AbstractArray , ndims:: Type{Val{N}} ) where N
114
- reshape (parent, rdims (() , indices (parent), Val{N} ))
114
+ reshape (parent, rdims (Val{N} , indices (parent)))
115
115
end
116
+
116
117
# Move elements from inds to out until out reaches the desired
117
118
# dimensionality N, either filling with OneTo(1) or collapsing the
118
119
# product of trailing dims into the last element
119
- @pure rdims (out:: NTuple{N,Any} , inds:: Tuple{} , :: Type{Val{N}} ) where {N} = out
120
- @pure function rdims (out:: NTuple{N,Any} , inds:: Tuple{Any, Vararg{Any}} , :: Type{Val{N}} ) where N
121
- l = length (last (out)) * prod (map (length, inds))
122
- (front (out)... , OneTo (l))
123
- end
124
- @pure rdims (out:: Tuple , inds:: Tuple{} , :: Type{Val{N}} ) where {N} = rdims ((out... , OneTo (1 )), (), Val{N})
125
- @pure rdims (out:: Tuple , inds:: Tuple{Any, Vararg{Any}} , :: Type{Val{N}} ) where {N} = rdims ((out... , first (inds)), tail (inds), Val{N})
120
+ rdims_trailing (l, inds... ) = length (l) * rdims_trailing (inds... )
121
+ rdims_trailing (l) = length (l)
122
+ rdims (out:: Type{Val{N}} , inds:: Tuple ) where {N} = rdims (ntuple (i -> OneTo (1 ), Val{N}), inds)
123
+ rdims (out:: Tuple{} , inds:: Tuple{} ) = () # N == 0, M == 0
124
+ rdims (out:: Tuple{} , inds:: Tuple{Any} ) = throw (ArgumentError (" new dimensions cannot be empty" )) # N == 0
125
+ rdims (out:: Tuple{} , inds:: NTuple{M,Any} ) where {M} = throw (ArgumentError (" new dimensions cannot be empty" )) # N == 0
126
+ rdims (out:: Tuple{Any} , inds:: Tuple{} ) = out # N == 1, M == 0
127
+ rdims (out:: NTuple{N,Any} , inds:: Tuple{} ) where {N} = out # N > 1, M == 0
128
+ rdims (out:: Tuple{Any} , inds:: Tuple{Any} ) = inds # N == 1, M == 1
129
+ rdims (out:: Tuple{Any} , inds:: NTuple{M,Any} ) where {M} = (OneTo (rdims_trailing (inds... )),) # N == 1, M > 1
130
+ rdims (out:: NTuple{N,Any} , inds:: NTuple{N,Any} ) where {N} = inds # N > 1, M == N
131
+ rdims (out:: NTuple{N,Any} , inds:: NTuple{M,Any} ) where {N,M} = (first (inds), rdims (tail (out), tail (inds))... ) # N > 1, M > 1, M != N
132
+
126
133
127
134
# _reshape on Array returns an Array
128
135
_reshape (parent:: Vector , dims:: Dims{1} ) = parent
0 commit comments