Skip to content

Commit f0743ce

Browse files
iamed2 authored and timholy committed
Add CategoricalVector and collapse (#88)
1 parent fe61993 commit f0743ce

10 files changed

+362
-4
lines changed

REQUIRE

+1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
julia 0.6
22
IntervalSets 0.1
3+
IterTools
34
RangeArrays

src/AxisArrays.jl

+4-1
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,11 @@ __precompile__()
33
module AxisArrays
44

55
using Base: tail
6+
import Base.Iterators: repeated
67
using RangeArrays, IntervalSets
8+
using IterTools
79

8-
export AxisArray, Axis, axisnames, axisvalues, axisdim, axes, atindex, atvalue
10+
export AxisArray, Axis, axisnames, axisvalues, axisdim, axes, atindex, atvalue, collapse
911

1012
# From IntervalSets:
1113
export ClosedInterval, ..
@@ -15,6 +17,7 @@ include("intervals.jl")
1517
include("search.jl")
1618
include("indexing.jl")
1719
include("sortedvector.jl")
20+
include("categoricalvector.jl")
1821
include("combine.jl")
1922

2023
end

src/categoricalvector.jl

+82
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
"""
2+
A CategoricalVector is an AbstractVector which is treated as a categorical axis regardless
3+
of the element type. Duplicate values are not allowed but are not filtered out.
4+
5+
A CategoricalVector axis can be indexed with an ClosedInterval, with a value, or with a
6+
vector of values. Use of a CategoricalVector{Tuple} axis allows indexing similar to the
7+
hierarchical index of the Python Pandas package or the R data.table package.
8+
9+
In general, indexing into a CategoricalVector will be much slower than the corresponding
10+
SortedVector or another sorted axis type, as linear search is required.
11+
12+
### Constructors
13+
14+
```julia
15+
CategoricalVector(x::AbstractVector)
16+
```
17+
18+
### Arguments
19+
20+
* `x::AbstractVector` : the wrapped vector
21+
22+
### Examples
23+
24+
```julia
25+
v = CategoricalVector(collect([1; 8; 10:15]))
26+
A = AxisArray(reshape(1:16, 8, 2), v, [:a, :b])
27+
A[Axis{:row}(1), :]
28+
A[Axis{:row}(10), :]
29+
A[Axis{:row}([1, 10]), :]
30+
31+
## Hierarchical index example with three key levels
32+
33+
data = reshape(1.:40., 20, 2)
34+
v = collect(zip([:a, :b, :c][rand(1:3,20)], [:x,:y][rand(1:2,20)], [:x,:y][rand(1:2,20)]))
35+
A = AxisArray(data, CategoricalVector(v), [:a, :b])
36+
A[:b, :]
37+
A[[:a,:c], :]
38+
A[(:a,:x), :]
39+
A[(:a,:x,:x), :]
40+
```
41+
"""
42+
immutable CategoricalVector{T, A<:AbstractVector{T}} <: AbstractVector{T}
43+
data::A
44+
end
45+
46+
# Outer constructor: infer both type parameters from the vector being wrapped.
CategoricalVector(data::AbstractVector{T}) where {T} =
    CategoricalVector{T, typeof(data)}(data)
49+
50+
# Scalar indexing returns the wrapped element unchanged.
function Base.getindex(v::CategoricalVector, idx::Int)
    return v.data[idx]
end

# Vector indexing re-wraps the selection so the result is still a CategoricalVector.
function Base.getindex(v::CategoricalVector, idx::AbstractVector)
    return CategoricalVector(v.data[idx])
end

# Shape queries all forward to the wrapped vector.
function Base.length(v::CategoricalVector)
    return length(v.data)
end

function Base.size(v::CategoricalVector)
    return size(v.data)
end

function Base.size(v::CategoricalVector, i)
    return size(v.data, i)
end

function Base.indices(v::CategoricalVector)
    return indices(v.data)
end
57+
58+
# Every CategoricalVector type reports the Categorical trait, whatever its element type.
function axistrait(::Type{CategoricalVector{T,A}}) where {T,A}
    return Categorical
end

# No validation is performed on a CategoricalVector axis.
function checkaxis(::CategoricalVector)
    return nothing
end
60+
61+
62+
## Special indexing for CategoricalVector{<:Tuple} axes, giving something like the
## hierarchical indexing of pandas.

# A non-tuple index is treated as a one-element key prefix.
function axisindexes(ax::Axis{S,CategoricalVector{T,A}}, idx) where {T<:Tuple,S,A}
    return axisindexes(ax, (idx,))
end

# A tuple index selects every position whose key starts with `idx` (linear scan).
function axisindexes(ax::Axis{S,CategoricalVector{T,A}}, idx::Tuple) where {T<:Tuple,S,A}
    keys = ax.val
    matches = Int[]
    for position in first(indices(keys))
        _tuple_matches(keys[position], idx) && push!(matches, position)
    end
    return matches
end
70+
71+
# Return `true` when `idx` is a prefix of `element`, i.e. `idx` is no longer than
# `element` and agrees with it (by `==`) at every position it covers.
function _tuple_matches(element::Tuple, idx::Tuple)
    if length(idx) > length(element)
        return false
    end

    for position in 1:length(idx)
        if !(element[position] == idx[position])
            return false
        end
    end

    return true
end
80+
81+
# Index a tuple-keyed categorical axis with an array of keys: look each key up in turn
# and concatenate the matching positions in the order the keys were given.
#
# The result is accumulated into an `Int[]` instead of `vcat([...]...)`, so an empty
# `idx` yields an empty integer vector (argument-less `vcat()` gives `Any[]`) and the
# per-key results are never splatted into a call.
function axisindexes(ax::Axis{S,CategoricalVector{T,A}}, idx::AbstractArray) where {T<:Tuple,S,A}
    matches = Int[]
    for key in idx
        append!(matches, axisindexes(ax, key))
    end
    return matches
end

src/combine.jl

+194
Original file line numberDiff line numberDiff line change
@@ -139,3 +139,197 @@ function Base.join{T,N,D,Ax}(As::AxisArray{T,N,D,Ax}...; fillvalue::T=zero(T),
139139
return result
140140

141141
end #join
142+
143+
# Lazily build the collapsed-axis keys for a single array: one tuple per combination of
# the given axes' values, each prefixed with `array_name`.
function _collapse_array_axes(array_name, array_axes...)
    axis_values = map(ax -> ax.val, array_axes)
    combinations = product(axis_values...)
    return (
        (array_name, (combo isa Tuple ? combo : (combo,))...) for combo in combinations
    )
end
146+
147+
# Concatenate the collapsed-axis keys of every array, pairing each array's label with
# that array's own axes.
function _collapse_axes(array_names, array_axes)
    keys_per_array = map(
        (label, label_axes) -> _collapse_array_axes(label, label_axes...),
        array_names,
        array_axes,
    )
    return collect(Iterators.flatten(keys_per_array))
end
152+
153+
# Apply `Base.IteratorsMD.split(_, Val{N})` to each argument and return the results
# as a tuple, in argument order.
function _splitall(::Type{Val{N}}, As...) where {N}
    return map(A -> Base.IteratorsMD.split(A, Val{N}), As)
end
156+
157+
# Apply `reshape(_, Val{N})` to each argument and return the results as a tuple,
# in argument order.
function _reshapeall(::Type{Val{N}}, As...) where {N}
    return map(A -> reshape(A, Val{N}), As)
end
160+
161+
# Validate that every axis in `common_axis_tuple` has the same name as the first one.
# Throws an `ArgumentError` on the first mismatch and returns `nothing` otherwise.
function _check_common_axes(common_axis_tuple)
    reference_name = axisname(first(common_axis_tuple))

    for other_name in axisname.(common_axis_tuple[2:end])
        if !(reference_name === other_name)
            throw(ArgumentError("Leading common axes must have the same name in each array"))
        end
    end

    return nothing
end
168+
169+
# Compute the element type of the collapsed axis: for each array, the key type is a tuple
# of the label type followed by the element types of that array's trailing axes; the
# result is the `typejoin` of those per-array key types.
function _collapsed_axis_eltype(LType, trailing_axes)
    key_types = map(axs -> Tuple{LType, map(eltype, axs)...}, trailing_axes)
    return typejoin(key_types...)
end
176+
177+
# No labels given: label the arrays 1, 2, ..., AN and forward to the labelled method.
function collapse(::Type{Val{N}}, As::Vararg{AxisArray, AN}) where {N, AN}
    default_labels = ntuple(identity, Val{AN})
    return collapse(Val{N}, default_labels, As...)
end
180+
181+
# Array type given but no labels: label the arrays 1, 2, ..., AN and forward to the
# fully specified method.
function collapse(::Type{Val{N}}, ::Type{NewArrayType}, As::Vararg{AxisArray, AN}) where {N, AN, NewArrayType<:AbstractArray}
    default_labels = ntuple(identity, Val{AN})
    return collapse(Val{N}, NewArrayType, default_labels, As...)
end
184+
185+
# Labels given but no array type: default the storage to an `Array` with the promoted
# element type of the inputs and `N + 1` dimensions, then forward to the fully specified
# method. In the generator `As` holds the argument types, so both values are computed
# from types alone and spliced into the returned call.
@generated function collapse(::Type{Val{N}}, labels::NTuple{AN, LType}, As::Vararg{AxisArray, AN}) where {N, AN, LType}
    storage_eltype = Base.promote_eltype(As...)
    storage_ndims = Int(N) + 1

    return :(collapse(Val{N}, Array{$storage_eltype, $storage_ndims}, labels, As...))
end
193+
194+
"""
195+
collapse(::Type{Val{N}}, As::AxisArray...) -> AxisArray
196+
collapse(::Type{Val{N}}, labels::Tuple, As::AxisArray...) -> AxisArray
197+
collapse(::Type{Val{N}}, ::Type{NewArrayType}, As::AxisArray...) -> AxisArray
198+
collapse(::Type{Val{N}}, ::Type{NewArrayType}, labels::Tuple, As::AxisArray...) -> AxisArray
199+
200+
Collapses `AxisArray`s with `N` equal leading axes into a single `AxisArray`.
201+
All additional axes in any of the arrays are collapsed into a single additional
202+
axis of type `Axis{:collapsed, CategoricalVector{Tuple}}`.
203+
204+
### Arguments
205+
206+
* `::Type{Val{N}}`: the greatest common dimension to share between all input
207+
arrays. The remaining axes are collapsed. All `N` axes must be common
208+
to each input array, at the same dimension. Values from `0` up to the
209+
minimum number of dimensions across all input arrays are allowed.
210+
* `labels::Tuple`: (optional) an index for each array in `As` used as the leading element in
211+
the index tuples in the `:collapsed` axis. Defaults to `1:length(As)`.
212+
* `::Type{NewArrayType<:AbstractArray{_, N+1}}`: (optional) the desired underlying array
213+
type for the returned `AxisArray`.
214+
* `As::AxisArray...`: `AxisArray`s to be collapsed together.
215+
216+
### Examples
217+
218+
```
219+
julia> price_data = AxisArray(rand(10), Axis{:time}(Date(2016,01,01):Day(1):Date(2016,01,10)))
220+
1-dimensional AxisArray{Float64,1,...} with axes:
221+
:time, 2016-01-01:1 day:2016-01-10
222+
And data, a 10-element Array{Float64,1}:
223+
0.885014
224+
0.418562
225+
0.609344
226+
0.72221
227+
0.43656
228+
0.840304
229+
0.455337
230+
0.65954
231+
0.393801
232+
0.260207
233+
234+
julia> size_data = AxisArray(rand(10,2), Axis{:time}(Date(2016,01,01):Day(1):Date(2016,01,10)), Axis{:measure}([:area, :volume]))
235+
2-dimensional AxisArray{Float64,2,...} with axes:
236+
:time, 2016-01-01:1 day:2016-01-10
237+
:measure, Symbol[:area, :volume]
238+
And data, a 10×2 Array{Float64,2}:
239+
0.159434 0.456992
240+
0.344521 0.374623
241+
0.522077 0.313256
242+
0.994697 0.320953
243+
0.95104 0.900526
244+
0.921854 0.729311
245+
0.000922581 0.148822
246+
0.449128 0.761714
247+
0.650277 0.135061
248+
0.688773 0.513845
249+
250+
julia> collapsed = collapse(Val{1}, (:price, :size), price_data, size_data)
251+
2-dimensional AxisArray{Float64,2,...} with axes:
252+
:time, 2016-01-01:1 day:2016-01-10
253+
:collapsed, Tuple{Symbol,Vararg{Symbol,N} where N}[(:price,), (:size, :area), (:size, :volume)]
254+
And data, a 10×3 Array{Float64,2}:
255+
0.885014 0.159434 0.456992
256+
0.418562 0.344521 0.374623
257+
0.609344 0.522077 0.313256
258+
0.72221 0.994697 0.320953
259+
0.43656 0.95104 0.900526
260+
0.840304 0.921854 0.729311
261+
0.455337 0.000922581 0.148822
262+
0.65954 0.449128 0.761714
263+
0.393801 0.650277 0.135061
264+
0.260207 0.688773 0.513845
265+
266+
julia> collapsed[Axis{:collapsed}(:size)] == size_data
267+
true
268+
```
269+
270+
"""
271+
@generated function collapse(::Type{Val{N}},
                             ::Type{NewArrayType},
                             labels::NTuple{AN, LType},
                             As::Vararg{AxisArray, AN}) where {N, AN, LType, NewArrayType<:AbstractArray}
    # Generator stage: `As` is a tuple of the argument *types* here, so everything up to
    # the `quote` works purely on type parameters.
    if N < 0
        throw(ArgumentError("collapse dimension N must be at least 0"))
    end

    # All N leading axes must exist in every input, so N is bounded by the smallest
    # dimensionality among the inputs.
    if N > minimum(ndims.(As))
        throw(ArgumentError(
            "collapse dimension N must not be greater than the minimum number of " *
            "dimensions across all input arrays"
        ))
    end

    collapsed_dim = Val{N + 1}
    collapsed_dim_int = Int(N) + 1

    # Split each array's axis types into the N leading ones and the remainder.
    common_axes, trailing_axes = zip(_splitall(Val{N}, axisparams.(As)...)...)

    # Check, per leading dimension, that the axis names agree across all arrays.
    foreach(_check_common_axes, zip(common_axes...))

    new_common_axes = first(common_axes)
    collapsed_axis_eltype = _collapsed_axis_eltype(LType, trailing_axes)
    collapsed_axis_type = CategoricalVector{collapsed_axis_eltype, Vector{collapsed_axis_eltype}}

    new_axes_type = Tuple{new_common_axes..., Axis{:collapsed, collapsed_axis_type}}
    new_eltype = Base.promote_eltype(As...)

    quote
        # Run-time stage: `As` now holds the actual arrays.
        common_axes, trailing_axes = zip(_splitall(Val{N}, axes.(As)...)...)

        # Axis values are only known at run time; every array must carry the same
        # values as the first one on each of the leading axes.
        for common_axis_tuple in zip(common_axes...)
            if !isempty(common_axis_tuple)
                for common_axis in common_axis_tuple[2:end]
                    if !all(axisvalues(common_axis) .== axisvalues(common_axis_tuple[1]))
                        throw(ArgumentError(
                            """
                            Leading common axes must be identical across
                            all input arrays"""
                        ))
                    end
                end
            end
        end

        # Reshape every input to N + 1 dimensions and concatenate along the last one.
        array_data = cat($collapsed_dim, _reshapeall($collapsed_dim, As...)...)

        axis_array_type = AxisArray{
            $new_eltype,
            $collapsed_dim_int,
            $NewArrayType,
            $new_axes_type
        }

        # Leading axes are taken from the first array; the collapsed axis holds one
        # tuple key per column contributed by each input.
        new_axes = (
            first(common_axes)...,
            Axis{:collapsed, $collapsed_axis_type}($collapsed_axis_type(_collapse_axes(labels, trailing_axes))),
        )

        return axis_array_type(array_data, new_axes)
    end
end

src/core.jl

+9
Original file line numberDiff line numberDiff line change
@@ -503,6 +503,15 @@ end
503503
axes(A::AbstractArray) = default_axes(A)
504504
axes(A::AbstractArray, dim::Int) = default_axes(A)[dim]
505505

506+
"""
507+
axisparams(::AxisArray) -> Vararg{::Type{Axis}}
508+
axisparams(::Type{AxisArray}) -> Vararg{::Type{Axis}}
509+
510+
Returns the axis parameters for an AxisArray.
511+
"""
512+
axisparams{T,N,D,Ax}(::AxisArray{T,N,D,Ax}) = (Ax.parameters...)
513+
axisparams{T,N,D,Ax}(::Type{AxisArray{T,N,D,Ax}}) = (Ax.parameters...)
514+
506515
### Axis traits ###
507516
abstract type AxisTrait end
508517
immutable Dimensional <: AxisTrait end

src/indexing.jl

+16-1
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,17 @@ end
231231
ex = Expr(:tuple)
232232
n = 0
233233
for i=1:length(I)
234+
if axistrait(I[i]) <: Categorical && i <= length(Ax.parameters)
235+
if I[i] <: Axis
236+
push!(ex.args, :(axisindexes(A.axes[$i], I[$i].val)))
237+
else
238+
push!(ex.args, :(axisindexes(A.axes[$i], I[$i])))
239+
end
240+
n += 1
241+
242+
continue
243+
end
244+
234245
if I[i] <: Idx
235246
push!(ex.args, :(I[$i]))
236247
n += 1
@@ -243,7 +254,11 @@ end
243254
end
244255
n += length(I[i])
245256
elseif i <= length(Ax.parameters)
246-
push!(ex.args, :(axisindexes(A.axes[$i], I[$i])))
257+
if I[i] <: Axis
258+
push!(ex.args, :(axisindexes(A.axes[$i], I[$i].val)))
259+
else
260+
push!(ex.args, :(axisindexes(A.axes[$i], I[$i])))
261+
end
247262
n += 1
248263
else
249264
push!(ex.args, :(error("dimension ", $i, " does not have an axis to index")))

test/categoricalvector.jl

+21
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# CategoricalVector with a hierarchical index: the axis keys are 3-tuples and the array is
# indexed with key prefixes. The expected row ranges below depend on the seeded random
# stream, so the three `rand` draws must stay in this order.
srand(1234)
data = reshape(1.:40., 20, 2)
v = let
    first_level = [:a, :b, :c][rand(1:3,20)]
    second_level = [:x,:y][rand(1:2,20)]
    third_level = [:x,:y][rand(1:2,20)]
    collect(zip(first_level, second_level, third_level))
end
idx = sortperm(v)
A = AxisArray(data[idx,:], AxisArrays.CategoricalVector(v[idx]), [:a, :b])
@test A[:b, :] == A[5:12, :]
@test A[[:a,:c], :] == A[[1:4;13:end], :]
@test A[(:a,:y), :] == A[2:4, :]
@test A[(:c,:y,:y), :] == A[16:end, :]
@test AxisArrays.axistrait(axes(A)[1]) <: AxisArrays.Categorical

# CategoricalVector with plain integer keys.
v = AxisArrays.CategoricalVector(collect([1; 8; 10:15]))
# `A` is still the hierarchical array from above at this point.
@test AxisArrays.axistrait(axes(A)[1]) <: AxisArrays.Categorical
A = AxisArray(reshape(1:16, 8, 2), v, [:a, :b])
# Indexing with a CategoricalVector wrapped in an Axis.
@test A[Axis{:row}(AxisArrays.CategoricalVector([15]))] ==
    AxisArray(reshape(A.data[8, :], 1, 2), AxisArrays.CategoricalVector([15]), [:a, :b])
@test A[Axis{:row}(AxisArrays.CategoricalVector([15])), 1] ==
    AxisArray([A.data[8, 1]], AxisArrays.CategoricalVector([15]))
@test AxisArrays.axistrait(axes(A)[1]) <: AxisArrays.Categorical

# TODO: maybe make this work? Would require removing or modifying Base.getindex(A::AxisArray, idxs::Idx...)
# @test A[AxisArrays.CategoricalVector([15]), 1] == AxisArray([A.data[8, 1]], AxisArrays.CategoricalVector([15]))

0 commit comments

Comments
 (0)