@@ -1160,30 +1160,32 @@ cat_similar(A::AbstractArray, T, shape) = similar(A, T, shape)
1160
1160
1161
1161
cat_shape (dims, shape:: Tuple ) = shape
1162
1162
@inline cat_shape (dims, shape:: Tuple , nshape:: Tuple , shapes:: Tuple... ) =
1163
- cat_shape (dims, _cshp (dims, (), shape, nshape), shapes... )
1164
-
1165
- _cshp (:: Tuple{} , out, :: Tuple{} , :: Tuple{} ) = out
1166
- _cshp (:: Tuple{} , out, :: Tuple{} , nshape) = (out... , nshape... )
1167
- _cshp (dims, out, :: Tuple{} , :: Tuple{} ) = (out... , map (b -> 1 , dims)... )
1168
- @inline _cshp (dims, out, shape, :: Tuple{} ) =
1169
- _cshp (tail (dims), (out... , shape[1 ] + dims[1 ]), tail (shape), ())
1170
- @inline _cshp (dims, out, :: Tuple{} , nshape) =
1171
- _cshp (tail (dims), (out... , nshape[1 ]), (), tail (nshape))
1172
- @inline function _cshp (:: Tuple{} , out, shape, :: Tuple{} )
1173
- _cs (length (out) + 1 , false , shape[1 ], 1 )
1174
- _cshp ((), (out... , 1 ), tail (shape), ())
1175
- end
1176
- @inline function _cshp (:: Tuple{} , out, shape, nshape)
1177
- next = _cs (length (out) + 1 , false , shape[1 ], nshape[1 ])
1178
- _cshp ((), (out... , next), tail (shape), tail (nshape))
1179
- end
1180
- @inline function _cshp (dims, out, shape, nshape)
1181
- next = _cs (length (out) + 1 , dims[1 ], shape[1 ], nshape[1 ])
1182
- _cshp (tail (dims), (out... , next), tail (shape), tail (nshape))
1183
- end
1184
-
1185
- _cs (d, concat, a, b) = concat ? (a + b) : (a == b ? a : throw (DimensionMismatch (string (
1186
- " mismatch in dimension " , d, " (expected " , a, " got " , b, " )" ))))
1163
+ cat_shape (dims, _cshp (1 , dims, shape, nshape), shapes... )
1164
+
1165
+ _cshp (ndim:: Int , :: Tuple{} , :: Tuple{} , :: Tuple{} ) = ()
1166
+ _cshp (ndim:: Int , :: Tuple{} , :: Tuple{} , nshape) = nshape
1167
+ _cshp (ndim:: Int , dims, :: Tuple{} , :: Tuple{} ) = ntuple (b -> 1 , Val{length (dims)})
1168
+ @inline _cshp (ndim:: Int , dims, shape, :: Tuple{} ) =
1169
+ (shape[1 ] + dims[1 ], _cshp (ndim + 1 , tail (dims), tail (shape), ())... )
1170
+ @inline _cshp (ndim:: Int , dims, :: Tuple{} , nshape) =
1171
+ (nshape[1 ], _cshp (ndim + 1 , tail (dims), (), tail (nshape))... )
1172
+ @inline function _cshp (ndim:: Int , :: Tuple{} , shape, :: Tuple{} )
1173
+ _cs (ndim, shape[1 ], 1 )
1174
+ (1 , _cshp (ndim + 1 , (), tail (shape), ())... )
1175
+ end
1176
+ @inline function _cshp (ndim:: Int , :: Tuple{} , shape, nshape)
1177
+ next = _cs (ndim, shape[1 ], nshape[1 ])
1178
+ (next, _cshp (ndim + 1 , (), tail (shape), tail (nshape))... )
1179
+ end
1180
+ @inline function _cshp (ndim:: Int , dims, shape, nshape)
1181
+ a = shape[1 ]
1182
+ b = nshape[1 ]
1183
+ next = dims[1 ] ? a + b : _cs (ndim, a, b)
1184
+ (next, _cshp (ndim + 1 , tail (dims), tail (shape), tail (nshape))... )
1185
+ end
1186
+
1187
+ _cs (d, a, b) = (a == b ? a : throw (DimensionMismatch (
1188
+ " mismatch in dimension $d (expected $a got $b )" )))
1187
1189
1188
1190
dims2cat {n} (:: Type{Val{n}} ) = ntuple (i -> (i == n), Val{n})
1189
1191
dims2cat (dims) = ntuple (i -> (i in dims), maximum (dims))
0 commit comments