Skip to content

Commit 0bc92de

Browse files
committed
Add flatten
1 parent 46982e2 commit 0bc92de

File tree

4 files changed

+106
-1
lines changed

4 files changed

+106
-1
lines changed

REQUIRE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
julia 0.5
22
IntervalSets
3+
Iterators
34
RangeArrays
45
Compat 0.19

src/AxisArrays.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,10 @@ module AxisArrays
44

55
using Base: tail
66
using RangeArrays, IntervalSets
7+
using Iterators
78
using Compat
89

9-
export AxisArray, Axis, axisnames, axisvalues, axisdim, axes, atindex, atvalue
10+
export AxisArray, Axis, axisnames, axisvalues, axisdim, axes, atindex, atvalue, flatten
1011

1112
# From IntervalSets:
1213
export ClosedInterval, ..

src/combine.jl

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,3 +139,89 @@ function Base.join{T,N,D,Ax}(As::AxisArray{T,N,D,Ax}...; fillvalue::T=zero(T),
139139
return result
140140

141141
end #join
142+
143+
function greatest_common_axis(As::AxisArray...)
144+
length(As) == 1 && return ndims(first(As))
145+
146+
for (i, zip_axes) in enumerate(zip(axes.(As)...))
147+
if !all(ax -> ax == zip_axes[1], zip_axes[2:end])
148+
return i - 1
149+
end
150+
end
151+
152+
return minimum(map(ndims, As))
153+
end
154+
155+
function flatten_array_axes(array_name, array_axes)
156+
map(zip(repeated(array_name), product(map(Ax->Ax.val, array_axes)...))) do tup
157+
tup_name, tup_idx = tup
158+
return (tup_name, tup_idx...)
159+
end
160+
end
161+
162+
function flatten_axes(array_names, array_axes)
163+
collect(chain(map(flatten_array_axes, array_names, array_axes)...))
164+
end
165+
166+
"""
167+
flatten(As::AxisArray...) -> AxisArray
168+
flatten(last_dim::Integer, As::AxisArray...) -> AxisArray
169+
170+
Concatenates AxisArrays with equal leading axes into a single AxisArray.
171+
All additional axes in any of the arrays are flattened into a single additional
172+
CategoricalVector{Tuple} axis.
173+
174+
### Arguments
175+
176+
* `last_dim::Integer`: (optional) the greatest common dimension to share between all input
177+
arrays. The remaining axes are flattened. If this argument is not
178+
provided, the greatest common axis found among the input arrays is
179+
used. All preceeding axes must also be common to each input array, at
180+
the same dimension. Values from 0 up to one more than the minimum
181+
number of dimensions across all input arrays are allowed.
182+
* `As::AxisArray...`: AxisArrays to be flattened together.
183+
"""
184+
function flatten(As::AxisArray...; kwargs...)
185+
gca = greatest_common_axis(As...)
186+
187+
return _flatten(gca, As...; kwargs...)
188+
end
189+
190+
function flatten(last_dim::Integer, As::AxisArray...; kwargs...)
191+
last_dim >= 0 || throw(ArgumentError("last_dim must be at least 0"))
192+
193+
if last_dim > minimum(map(ndims, As))
194+
throw(ArgumentError(
195+
"There must be at least $last_dim (last_dim) axes in each argument"
196+
))
197+
end
198+
199+
if last_dim > greatest_common_axis(As...)
200+
throw(ArgumentError(
201+
"The first $last_dim axes don't all match across all arguments"
202+
))
203+
end
204+
205+
return _flatten(last_dim, As...; kwargs...)
206+
end
207+
208+
function _flatten(
209+
last_dim::Integer,
210+
As::AxisArray...;
211+
array_names=1:length(As),
212+
axis_name=nothing,
213+
)
214+
common_axes = axes(As[1])[1:last_dim]
215+
216+
if axis_name === nothing
217+
axis_name = _defaultdimname(last_dim + 1)
218+
elseif !isa(axis_name, Symbol)
219+
throw(ArgumentError("axis_name must be a Symbol"))
220+
end
221+
222+
new_data = cat(last_dim + 1, (view(A.data, repeated(:, last_dim + 1)...) for A in As)...)
223+
new_axis = flatten_axes(array_names, map(A -> axes(A)[last_dim+1:end], As))
224+
225+
# TODO: Consider creating a SortedVector axis when all flattened axes are Dimensional
226+
return AxisArray(new_data, common_axes..., CategoricalVector(new_axis))
227+
end

test/combine.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,3 +47,20 @@ ABdata[3:6,3:6,:,2] = Bdata
4747
@test join(A,B,method=:left) == AxisArray(ABdata[1:4, 1:4, :, :], A.axes...)
4848
@test join(A,B,method=:right) == AxisArray(ABdata[3:6, 3:6, :, :], B.axes...)
4949
@test join(A,B,method=:outer) == join(A,B)
50+
51+
# flatten
52+
A1 = AxisArray(A1data, Axis{:X}(1:2), Axis{:Y}(1:2))
53+
A2 = AxisArray(reshape(A2data, size(A2data)..., 1), Axis{:X}(1:2), Axis{:Y}(1:2), Axis{:Z}([:foo]))
54+
55+
@test flatten(A1, A2; array_names=[:A1, :A2]) == AxisArray(cat(3, A1data, A2data), Axis{:X}(1:2), Axis{:Y}(1:2), Axis{:page}(CategoricalVector([(:A1,), (:A2, :foo)])))
56+
@test flatten(A1; array_names=[:foo]) == AxisArray(reshape(A1, 2, 2, 1), Axis{:X}(1:2), Axis{:Y}(1:2), Axis{:page}(CategoricalVector([(:foo,)])))
57+
@test flatten(A1; array_names=[:a], axis_name=:ax) == AxisArray(reshape(A1.data, size(A1)..., 1), axes(A1)..., Axis{:ax}(CategoricalVector([(:a,)])))
58+
59+
@test_throws ArgumentError flatten(-1, A1)
60+
@test_throws ArgumentError flatten(10, A1)
61+
62+
A1ᵀ = transpose(A1)
63+
@test flatten(A1, A1ᵀ) == flatten(0, A1, A1ᵀ)
64+
@test_throws ArgumentError flatten(-1, A1, A1ᵀ)
65+
@test_throws ArgumentError flatten(1, A1, A1ᵀ)
66+
@test_throws ArgumentError flatten(10, A1, A1ᵀ)

0 commit comments

Comments
 (0)