Skip to content

Commit 3c29df0

Browse files
committed
Reduce to flatten{N} only
1 parent 5958549 commit 3c29df0

File tree

6 files changed

+125
-70
lines changed

6 files changed

+125
-70
lines changed

src/AxisArrays.jl

+1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ __precompile__()
33
module AxisArrays
44

55
using Base: tail
6+
import Base.Iterators: repeated
67
using RangeArrays, IntervalSets
78
using IterTools
89
using Compat

src/categoricalvector.jl

+10-6
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,12 @@ A[(:a,:x), :]
4242
A[(:a,:x,:x), :]
4343
```
4444
"""
45-
immutable CategoricalVector{T} <: AbstractVector{T}
46-
data::AbstractVector{T}
45+
immutable CategoricalVector{T, A<:AbstractVector{T}} <: AbstractVector{T}
46+
data::A
47+
end
48+
49+
function CategoricalVector(data::AbstractVector{T}) where T
50+
CategoricalVector{T, typeof(data)}(data)
4751
end
4852

4953
Base.getindex(v::CategoricalVector, idx::Int) = v.data[idx]
@@ -54,16 +58,16 @@ Base.size(v::CategoricalVector) = size(v.data)
5458
Base.size(v::CategoricalVector, i) = size(v.data, i)
5559
Base.indices(v::CategoricalVector) = indices(v.data)
5660

57-
axistrait{T}(::Type{CategoricalVector{T}}) = Categorical
61+
axistrait(::Type{CategoricalVector{T,A}}) where {T,A} = Categorical
5862
checkaxis(::CategoricalVector) = nothing
5963

6064

6165
## Add some special indexing for CategoricalVector{Tuple}'s to achieve something like
6266
## Panda's hierarchical indexing
6367

64-
axisindexes{T<:Tuple,S}(ax::Axis{S,CategoricalVector{T}}, idx) = axisindexes(ax, (idx,))
68+
axisindexes{T<:Tuple,S,A}(ax::Axis{S,CategoricalVector{T,A}}, idx) = axisindexes(ax, (idx,))
6569

66-
function axisindexes{T<:Tuple,S}(ax::Axis{S,CategoricalVector{T}}, idx::Tuple)
70+
function axisindexes{T<:Tuple,S,A}(ax::Axis{S,CategoricalVector{T,A}}, idx::Tuple)
6771
collect(filter(ax_idx->_tuple_matches(ax.val[ax_idx], idx), indices(ax.val)...))
6872
end
6973

@@ -77,5 +81,5 @@ function _tuple_matches(element::Tuple, idx::Tuple)
7781
return true
7882
end
7983

80-
axisindexes{T<:Tuple,S}(ax::Axis{S,CategoricalVector{T}}, idx::AbstractArray) =
84+
axisindexes{T<:Tuple,S,A}(ax::Axis{S,CategoricalVector{T,A}}, idx::AbstractArray) =
8185
vcat([axisindexes(ax, i) for i in idx]...)

src/combine.jl

+89-55
Original file line numberDiff line numberDiff line change
@@ -140,88 +140,122 @@ function Base.join{T,N,D,Ax}(As::AxisArray{T,N,D,Ax}...; fillvalue::T=zero(T),
140140

141141
end #join
142142

143-
function greatest_common_axis(As::AxisArray...)
144-
length(As) == 1 && return ndims(first(As))
143+
function _flatten_array_axes(array_name, array_axes...)
144+
((array_name, (idx isa Tuple ? idx : (idx,))...) for idx in product((Ax.val for Ax in array_axes)...))
145+
end
145146

146-
for (i, zip_axes) in enumerate(zip(axes.(As)...))
147-
if !all(ax -> ax == zip_axes[1], zip_axes[2:end])
148-
return i - 1
149-
end
147+
function _flatten_axes(array_names, array_axes)
148+
collect(Iterators.flatten(map(array_names, array_axes) do tup_name, tup_array_axes
149+
_flatten_array_axes(tup_name, tup_array_axes...)
150+
end))
151+
end
152+
153+
function _splitall{N}(::Type{Val{N}}, As...)
154+
tuple((Base.IteratorsMD.split(A, Val{N}) for A in As)...)
155+
end
156+
157+
function _reshapeall{N}(::Type{Val{N}}, As...)
158+
tuple((reshape(A, Val{N}) for A in As)...)
159+
end
160+
161+
function _check_common_axes(common_axis_tuple)
162+
if !all(axisname(first(common_axis_tuple)) .=== axisname.(common_axis_tuple[2:end]))
163+
throw(ArgumentError("Leading common axes must have the same name in each array"))
150164
end
151165

152-
return minimum(map(ndims, As))
166+
return nothing
153167
end
154168

155-
function flatten_array_axes(array_name, array_axes)
156-
map(zip(repeated(array_name), product(map(Ax->Ax.val, array_axes)...))) do tup
157-
tup_name, tup_idx = tup
158-
return (tup_name, tup_idx...)
169+
function _flat_axis_eltype(LType, trailing_axes)
170+
eltypes = map(trailing_axes) do array_trailing_axes
171+
Tuple{LType, eltype.(array_trailing_axes)...}
159172
end
173+
174+
return typejoin(eltypes...)
160175
end
161176

162-
function flatten_axes(array_names, array_axes)
163-
collect(chain(map(flatten_array_axes, array_names, array_axes)...))
177+
function flatten{N, NA}(::Type{Val{N}}, As::Vararg{AxisArray, NA})
178+
flatten(Val{N}, ntuple(identity, Val{NA}), As...)
164179
end
165180

166181
"""
167182
flatten(As::AxisArray...) -> AxisArray
168-
flatten(last_dim::Integer, As::AxisArray...) -> AxisArray
183+
flatten(last_dim::Type{Val{N}}, As::AxisArray...) -> AxisArray
184+
flatten(last_dim::Type{Val{N}}, labels::Tuple, As::AxisArray...) -> AxisArray
169185
170-
Concatenates AxisArrays with equal leading axes into a single AxisArray.
186+
Concatenates AxisArrays with N equal leading axes into a single AxisArray.
171187
All additional axes in any of the arrays are flattened into a single additional
172188
CategoricalVector{Tuple} axis.
173189
174190
### Arguments
175191
176-
* `last_dim::Integer`: (optional) the greatest common dimension to share between all input
177-
arrays. The remaining axes are flattened. If this argument is not
178-
provided, the greatest common axis found among the input arrays is
179-
used. All preceeding axes must also be common to each input array, at
180-
the same dimension. Values from 0 up to one more than the minimum
181-
number of dimensions across all input arrays are allowed.
192+
* `::Type{Val{N}}`: the greatest common dimension to share between all input
193+
arrays. The remaining axes are flattened. All N axes must be common
194+
to each input array, at the same dimension. Values from 0 up to the
195+
minimum number of dimensions across all input arrays are allowed.
196+
* `labels::Tuple`: (optional) a label for each AxisArray in As which is used in the flat
197+
axis
182198
* `As::AxisArray...`: AxisArrays to be flattened together.
183199
"""
184-
function flatten(As::AxisArray...; kwargs...)
185-
gca = greatest_common_axis(As...)
186-
187-
return _flatten(gca, As...; kwargs...)
188-
end
189-
190-
function flatten(last_dim::Integer, As::AxisArray...; kwargs...)
191-
last_dim >= 0 || throw(ArgumentError("last_dim must be at least 0"))
192-
193-
if last_dim > minimum(map(ndims, As))
194-
throw(ArgumentError(
195-
"There must be at least $last_dim (last_dim) axes in each argument"
196-
))
200+
@generated function flatten{N, AN, LType}(::Type{Val{N}}, labels::NTuple{AN, LType}, As::Vararg{AxisArray, AN})
201+
if N < 0
202+
throw(ArgumentError("flatten dimension N must be at least 0"))
197203
end
198204

199-
if last_dim > greatest_common_axis(As...)
205+
if N > minimum(ndims.(As))
200206
throw(ArgumentError(
201-
"The first $last_dim axes don't all match across all arguments"
207+
"""
208+
flatten dimension N must not be greater than the maximum number of dimensions
209+
across all input arrays
210+
"""
202211
))
203212
end
204213

205-
return _flatten(last_dim, As...; kwargs...)
206-
end
214+
flat_dim = Val{N + 1}
215+
flat_dim_int = Int(N) + 1
207216

208-
function _flatten(
209-
last_dim::Integer,
210-
As::AxisArray...;
211-
array_names=1:length(As),
212-
axis_name=nothing,
213-
)
214-
common_axes = axes(As[1])[1:last_dim]
215-
216-
if axis_name === nothing
217-
axis_name = _defaultdimname(last_dim + 1)
218-
elseif !isa(axis_name, Symbol)
219-
throw(ArgumentError("axis_name must be a Symbol"))
220-
end
217+
common_axes, trailing_axes = zip(_splitall(Val{N}, axisparams.(As)...)...)
218+
219+
foreach(_check_common_axes, zip(common_axes...))
220+
221+
new_common_axes = first(common_axes)
222+
flat_axis_eltype = _flat_axis_eltype(LType, trailing_axes)
223+
flat_axis_type = CategoricalVector{flat_axis_eltype, Vector{flat_axis_eltype}}
224+
225+
new_axes_type = Tuple{new_common_axes..., Axis{:flat, flat_axis_type}}
226+
new_eltype = Base.promote_eltype(As...)
221227

222-
new_data = cat(last_dim + 1, (view(A.data, repeated(:, last_dim + 1)...) for A in As)...)
223-
new_axis = flatten_axes(array_names, map(A -> axes(A)[last_dim+1:end], As))
228+
quote
229+
common_axes, trailing_axes = zip(_splitall(Val{N}, axes.(As)...)...)
224230

225-
# TODO: Consider creating a SortedVector axis when all flattened axes are Dimensional
226-
return AxisArray(new_data, common_axes..., CategoricalVector(new_axis))
231+
for common_axis_tuple in zip(common_axes...)
232+
if !isempty(common_axis_tuple)
233+
for common_axis in common_axis_tuple[2:end]
234+
if !all(axisvalues(common_axis) .== axisvalues(common_axis_tuple[1]))
235+
throw(ArgumentError(
236+
"""
237+
Leading common axes must be identical across
238+
all input arrays"""
239+
))
240+
end
241+
end
242+
end
243+
end
244+
245+
array_data = cat($flat_dim, _reshapeall($flat_dim, As...)...)
246+
247+
axis_array_type = AxisArray{
248+
$new_eltype,
249+
$flat_dim_int,
250+
Array{$new_eltype, $flat_dim_int},
251+
$new_axes_type
252+
}
253+
254+
new_axes = (
255+
first(common_axes)...,
256+
Axis{:flat, $flat_axis_type}($flat_axis_type(_flatten_axes(labels, trailing_axes))),
257+
)
258+
259+
return axis_array_type(array_data, new_axes)
260+
end
227261
end

src/core.jl

+9
Original file line numberDiff line numberDiff line change
@@ -508,6 +508,15 @@ end
508508
axes(A::AbstractArray) = default_axes(A)
509509
axes(A::AbstractArray, dim::Int) = default_axes(A)[dim]
510510

511+
"""
512+
axisparams(::AxisArray) -> Vararg{::Type{Axis}}
513+
axisparams(::Type{AxisArray}) -> Vararg{::Type{Axis}}
514+
515+
Returns the axis parameters for an AxisArray.
516+
"""
517+
axisparams{T,N,D,Ax}(::AxisArray{T,N,D,Ax}) = (Ax.parameters...)
518+
axisparams{T,N,D,Ax}(::Type{AxisArray{T,N,D,Ax}}) = (Ax.parameters...)
519+
511520
### Axis traits ###
512521
@compat abstract type AxisTrait end
513522
immutable Dimensional <: AxisTrait end

test/combine.jl

+15-9
Original file line numberDiff line numberDiff line change
@@ -52,15 +52,21 @@ ABdata[3:6,3:6,:,2] = Bdata
5252
A1 = AxisArray(A1data, Axis{:X}(1:2), Axis{:Y}(1:2))
5353
A2 = AxisArray(reshape(A2data, size(A2data)..., 1), Axis{:X}(1:2), Axis{:Y}(1:2), Axis{:Z}([:foo]))
5454

55-
@test flatten(A1, A2; array_names=[:A1, :A2]) == AxisArray(cat(3, A1data, A2data), Axis{:X}(1:2), Axis{:Y}(1:2), Axis{:page}(CategoricalVector([(:A1,), (:A2, :foo)])))
56-
@test flatten(A1; array_names=[:foo]) == AxisArray(reshape(A1, 2, 2, 1), Axis{:X}(1:2), Axis{:Y}(1:2), Axis{:page}(CategoricalVector([(:foo,)])))
57-
@test flatten(A1; array_names=[:a], axis_name=:ax) == AxisArray(reshape(A1.data, size(A1)..., 1), axes(A1)..., Axis{:ax}(CategoricalVector([(:a,)])))
55+
@test @inferred(flatten(Val{2}, A1, A2)) == AxisArray(cat(3, A1data, A2data), Axis{:X}(1:2), Axis{:Y}(1:2), Axis{:flat}(CategoricalVector([(1,), (2, :foo)])))
56+
@test @inferred(flatten(Val{2}, A1)) == AxisArray(reshape(A1, 2, 2, 1), Axis{:X}(1:2), Axis{:Y}(1:2), Axis{:flat}(CategoricalVector([(1,)])))
57+
@test @inferred(flatten(Val{2}, A1)) == AxisArray(reshape(A1.data, size(A1)..., 1), axes(A1)..., Axis{:flat}(CategoricalVector([(1,)])))
5858

59-
@test_throws ArgumentError flatten(-1, A1)
60-
@test_throws ArgumentError flatten(10, A1)
59+
@test @inferred(flatten(Val{2}, (:A1, :A2), A1, A2)) == AxisArray(cat(3, A1data, A2data), Axis{:X}(1:2), Axis{:Y}(1:2), Axis{:flat}(CategoricalVector([(:A1,), (:A2, :foo)])))
60+
@test @inferred(flatten(Val{2}, (:foo,), A1)) == AxisArray(reshape(A1, 2, 2, 1), Axis{:X}(1:2), Axis{:Y}(1:2), Axis{:flat}(CategoricalVector([(:foo,)])))
61+
@test @inferred(flatten(Val{2}, (:a,), A1)) == AxisArray(reshape(A1.data, size(A1)..., 1), axes(A1)..., Axis{:flat}(CategoricalVector([(:a,)])))
62+
63+
@test @inferred(flatten(Val{0}, A1)) == AxisArray(vec(A1data), Axis{:flat}(CategoricalVector(collect(IterTools.product((1,), axisvalues(A1)...)))))
64+
@test @inferred(flatten(Val{1}, A1)) == AxisArray(A1data, Axis{:row}(1:2), Axis{:flat}(CategoricalVector(collect(IterTools.product((1,), axisvalues(A1)[2])))))
65+
66+
@test_throws ArgumentError flatten(Val{-1}, A1)
67+
@test_throws ArgumentError flatten(Val{10}, A1)
6168

6269
A1ᵀ = transpose(A1)
63-
@test flatten(A1, A1ᵀ) == flatten(0, A1, A1ᵀ)
64-
@test_throws ArgumentError flatten(-1, A1, A1ᵀ)
65-
@test_throws ArgumentError flatten(1, A1, A1ᵀ)
66-
@test_throws ArgumentError flatten(10, A1, A1ᵀ)
70+
@test_throws ArgumentError flatten(Val{-1}, A1, A1ᵀ)
71+
@test_throws ArgumentError flatten(Val{1}, A1, A1ᵀ)
72+
@test_throws ArgumentError flatten(Val{10}, A1, A1ᵀ)

test/runtests.jl

+1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using AxisArrays
22
using Base.Test
3+
import IterTools
34

45
@testset "AxisArrays" begin
56
# during this time there was an ambiguity in base with checkbounds_linear_indices

0 commit comments

Comments
 (0)