@@ -140,88 +140,122 @@ function Base.join{T,N,D,Ax}(As::AxisArray{T,N,D,Ax}...; fillvalue::T=zero(T),
140
140
141
141
end # join
142
142
143
- function greatest_common_axis (As:: AxisArray... )
144
- length (As) == 1 && return ndims (first (As))
143
+ function _flatten_array_axes (array_name, array_axes... )
144
+ ((array_name, (idx isa Tuple ? idx : (idx,)). .. ) for idx in product ((Ax. val for Ax in array_axes). .. ))
145
+ end
145
146
146
- for (i, zip_axes) in enumerate (zip (axes .(As)... ))
147
- if ! all (ax -> ax == zip_axes[1 ], zip_axes[2 : end ])
148
- return i - 1
149
- end
147
+ function _flatten_axes (array_names, array_axes)
148
+ collect (Iterators. flatten (map (array_names, array_axes) do tup_name, tup_array_axes
149
+ _flatten_array_axes (tup_name, tup_array_axes... )
150
+ end ))
151
+ end
152
+
153
+ function _splitall {N} (:: Type{Val{N}} , As... )
154
+ tuple ((Base. IteratorsMD. split (A, Val{N}) for A in As). .. )
155
+ end
156
+
157
+ function _reshapeall {N} (:: Type{Val{N}} , As... )
158
+ tuple ((reshape (A, Val{N}) for A in As). .. )
159
+ end
160
+
161
+ function _check_common_axes (common_axis_tuple)
162
+ if ! all (axisname (first (common_axis_tuple)) .=== axisname .(common_axis_tuple[2 : end ]))
163
+ throw (ArgumentError (" Leading common axes must have the same name in each array" ))
150
164
end
151
165
152
- return minimum ( map (ndims, As))
166
+ return nothing
153
167
end
154
168
155
- function flatten_array_axes (array_name, array_axes)
156
- map (zip (repeated (array_name), product (map (Ax-> Ax. val, array_axes)... ))) do tup
157
- tup_name, tup_idx = tup
158
- return (tup_name, tup_idx... )
169
+ function _flat_axis_eltype (LType, trailing_axes)
170
+ eltypes = map (trailing_axes) do array_trailing_axes
171
+ Tuple{LType, eltype .(array_trailing_axes)... }
159
172
end
173
+
174
+ return typejoin (eltypes... )
160
175
end
161
176
162
- function flatten_axes (array_names, array_axes )
163
- collect ( chain ( map (flatten_array_axes, array_names, array_axes) ... ) )
177
+ function flatten {N, NA} ( :: Type{Val{N}} , As :: Vararg{AxisArray, NA} )
178
+ flatten (Val{N}, ntuple (identity, Val{NA}), As ... )
164
179
end
165
180
166
181
"""
167
182
flatten(As::AxisArray...) -> AxisArray
168
- flatten(last_dim::Integer, As::AxisArray...) -> AxisArray
183
+ flatten(last_dim::Type{Val{N}}, As::AxisArray...) -> AxisArray
184
+ flatten(last_dim::Type{Val{N}}, labels::Tuple, As::AxisArray...) -> AxisArray
169
185
170
- Concatenates AxisArrays with equal leading axes into a single AxisArray.
186
+ Concatenates AxisArrays with N equal leading axes into a single AxisArray.
171
187
All additional axes in any of the arrays are flattened into a single additional
172
188
CategoricalVector{Tuple} axis.
173
189
174
190
### Arguments
175
191
176
- * `last_dim::Integer `: (optional) the greatest common dimension to share between all input
177
- arrays. The remaining axes are flattened. If this argument is not
178
- provided, the greatest common axis found among the input arrays is
179
- used. All preceeding axes must also be common to each input array, at
180
- the same dimension. Values from 0 up to one more than the minimum
181
- number of dimensions across all input arrays are allowed.
192
+ * `::Type{Val{N}} `: the greatest common dimension to share between all input
193
+ arrays. The remaining axes are flattened. All N axes must be common
194
+ to each input array, at the same dimension. Values from 0 up to the
195
+ minimum number of dimensions across all input arrays are allowed.
196
+ * `labels::Tuple`: (optional) a label for each AxisArray in As which is used in the flat
197
+ axis
182
198
* `As::AxisArray...`: AxisArrays to be flattened together.
183
199
"""
184
- function flatten (As:: AxisArray... ; kwargs... )
185
- gca = greatest_common_axis (As... )
186
-
187
- return _flatten (gca, As... ; kwargs... )
188
- end
189
-
190
- function flatten (last_dim:: Integer , As:: AxisArray... ; kwargs... )
191
- last_dim >= 0 || throw (ArgumentError (" last_dim must be at least 0" ))
192
-
193
- if last_dim > minimum (map (ndims, As))
194
- throw (ArgumentError (
195
- " There must be at least $last_dim (last_dim) axes in each argument"
196
- ))
200
+ @generated function flatten {N, AN, LType} (:: Type{Val{N}} , labels:: NTuple{AN, LType} , As:: Vararg{AxisArray, AN} )
201
+ if N < 0
202
+ throw (ArgumentError (" flatten dimension N must be at least 0" ))
197
203
end
198
204
199
- if last_dim > greatest_common_axis (As ... )
205
+ if N > minimum ( ndims .(As) )
200
206
throw (ArgumentError (
201
- " The first $last_dim axes don't all match across all arguments"
207
+ """
208
+ flatten dimension N must not be greater than the maximum number of dimensions
209
+ across all input arrays
210
+ """
202
211
))
203
212
end
204
213
205
- return _flatten (last_dim, As ... ; kwargs ... )
206
- end
214
+ flat_dim = Val{N + 1 }
215
+ flat_dim_int = Int (N) + 1
207
216
208
- function _flatten (
209
- last_dim:: Integer ,
210
- As:: AxisArray... ;
211
- array_names= 1 : length (As),
212
- axis_name= nothing ,
213
- )
214
- common_axes = axes (As[1 ])[1 : last_dim]
215
-
216
- if axis_name === nothing
217
- axis_name = _defaultdimname (last_dim + 1 )
218
- elseif ! isa (axis_name, Symbol)
219
- throw (ArgumentError (" axis_name must be a Symbol" ))
220
- end
217
+ common_axes, trailing_axes = zip (_splitall (Val{N}, axisparams .(As)... )... )
218
+
219
+ foreach (_check_common_axes, zip (common_axes... ))
220
+
221
+ new_common_axes = first (common_axes)
222
+ flat_axis_eltype = _flat_axis_eltype (LType, trailing_axes)
223
+ flat_axis_type = CategoricalVector{flat_axis_eltype, Vector{flat_axis_eltype}}
224
+
225
+ new_axes_type = Tuple{new_common_axes... , Axis{:flat , flat_axis_type}}
226
+ new_eltype = Base. promote_eltype (As... )
221
227
222
- new_data = cat (last_dim + 1 , ( view (A . data, repeated (:, last_dim + 1 ) ... ) for A in As) . .. )
223
- new_axis = flatten_axes (array_names, map (A -> axes (A)[last_dim + 1 : end ], As))
228
+ quote
229
+ common_axes, trailing_axes = zip ( _splitall (Val{N}, axes .( As)... ) ... )
224
230
225
- # TODO : Consider creating a SortedVector axis when all flattened axes are Dimensional
226
- return AxisArray (new_data, common_axes... , CategoricalVector (new_axis))
231
+ for common_axis_tuple in zip (common_axes... )
232
+ if ! isempty (common_axis_tuple)
233
+ for common_axis in common_axis_tuple[2 : end ]
234
+ if ! all (axisvalues (common_axis) .== axisvalues (common_axis_tuple[1 ]))
235
+ throw (ArgumentError (
236
+ """
237
+ Leading common axes must be identical across
238
+ all input arrays"""
239
+ ))
240
+ end
241
+ end
242
+ end
243
+ end
244
+
245
+ array_data = cat ($ flat_dim, _reshapeall ($ flat_dim, As... )... )
246
+
247
+ axis_array_type = AxisArray{
248
+ $ new_eltype,
249
+ $ flat_dim_int,
250
+ Array{$ new_eltype, $ flat_dim_int},
251
+ $ new_axes_type
252
+ }
253
+
254
+ new_axes = (
255
+ first (common_axes)... ,
256
+ Axis {:flat, $flat_axis_type} ($ flat_axis_type (_flatten_axes (labels, trailing_axes))),
257
+ )
258
+
259
+ return axis_array_type (array_data, new_axes)
260
+ end
227
261
end
0 commit comments