@@ -139,3 +139,89 @@ function Base.join{T,N,D,Ax}(As::AxisArray{T,N,D,Ax}...; fillvalue::T=zero(T),
139
139
return result
140
140
141
141
end # join
142
+
143
+ function greatest_common_axis (As:: AxisArray... )
144
+ length (As) == 1 && return ndims (first (As))
145
+
146
+ for (i, zip_axes) in enumerate (zip (axes .(As)... ))
147
+ if ! all (ax -> ax == zip_axes[1 ], zip_axes[2 : end ])
148
+ return i - 1
149
+ end
150
+ end
151
+
152
+ return minimum (map (ndims, As))
153
+ end
154
+
155
+ function flatten_array_axes (array_name, array_axes)
156
+ map (zip (repeated (array_name), product (map (Ax-> Ax. val, array_axes)... ))) do tup
157
+ tup_name, tup_idx = tup
158
+ return (tup_name, tup_idx... )
159
+ end
160
+ end
161
+
162
+ function flatten_axes (array_names, array_axes)
163
+ collect (chain (map (flatten_array_axes, array_names, array_axes)... ))
164
+ end
165
+
166
+ """
167
+ flatten(As::AxisArray...) -> AxisArray
168
+ flatten(last_dim::Integer, As::AxisArray...) -> AxisArray
169
+
170
+ Concatenates AxisArrays with equal leading axes into a single AxisArray.
171
+ All additional axes in any of the arrays are flattened into a single additional
172
+ CategoricalVector{Tuple} axis.
173
+
174
+ ### Arguments
175
+
176
+ * `last_dim::Integer`: (optional) the greatest common dimension to share between all input
177
+ arrays. The remaining axes are flattened. If this argument is not
178
+ provided, the greatest common axis found among the input arrays is
179
+ used. All preceeding axes must also be common to each input array, at
180
+ the same dimension. Values from 0 up to one more than the minimum
181
+ number of dimensions across all input arrays are allowed.
182
+ * `As::AxisArray...`: AxisArrays to be flattened together.
183
+ """
184
+ function flatten (As:: AxisArray... ; kwargs... )
185
+ gca = greatest_common_axis (As... )
186
+
187
+ return _flatten (gca, As... ; kwargs... )
188
+ end
189
+
190
+ function flatten (last_dim:: Integer , As:: AxisArray... ; kwargs... )
191
+ last_dim >= 0 || throw (ArgumentError (" last_dim must be at least 0" ))
192
+
193
+ if last_dim > minimum (map (ndims, As))
194
+ throw (ArgumentError (
195
+ " There must be at least $last_dim (last_dim) axes in each argument"
196
+ ))
197
+ end
198
+
199
+ if last_dim > greatest_common_axis (As... )
200
+ throw (ArgumentError (
201
+ " The first $last_dim axes don't all match across all arguments"
202
+ ))
203
+ end
204
+
205
+ return _flatten (last_dim, As... ; kwargs... )
206
+ end
207
+
208
+ function _flatten (
209
+ last_dim:: Integer ,
210
+ As:: AxisArray... ;
211
+ array_names= 1 : length (As),
212
+ axis_name= nothing ,
213
+ )
214
+ common_axes = axes (As[1 ])[1 : last_dim]
215
+
216
+ if axis_name === nothing
217
+ axis_name = _defaultdimname (last_dim + 1 )
218
+ elseif ! isa (axis_name, Symbol)
219
+ throw (ArgumentError (" axis_name must be a Symbol" ))
220
+ end
221
+
222
+ new_data = cat (last_dim + 1 , (view (A. data, repeated (:, last_dim + 1 )... ) for A in As). .. )
223
+ new_axis = flatten_axes (array_names, map (A -> axes (A)[last_dim+ 1 : end ], As))
224
+
225
+ # TODO : Consider creating a SortedVector axis when all flattened axes are Dimensional
226
+ return AxisArray (new_data, common_axes... , CategoricalVector (new_axis))
227
+ end
0 commit comments