Skip to content
This repository was archived by the owner on May 5, 2019. It is now read-only.

Commit d6104d7

Browse files
committed
Restore old grouping algorithm and improve it
Follow the strategy used by Pandas. The new implementation is more efficient since it avoids creating a NullableCategoricalArray: the integer codes are combined on the fly with those computed from previous columns. Hashing only happens once by giving arbitrary codes to levels in the first pass; after that, only integer codes are used. Move the per-column operations to separate functions which can be specialized by the compiler for each column type. This also allows using a more efficient method for CategoricalArray. Fix ordering of CategoricalArray levels when levels have been reordered, and sort null values last for consistency with other nullable arrays. Enable sorting by default since its cost is relatively small compared with the rest. Avoid some allocations by using in-place operations, and use Base.unique!().
1 parent 302d779 commit d6104d7

File tree

3 files changed

+176
-29
lines changed

3 files changed

+176
-29
lines changed

src/groupeddatatable/grouping.jl

+127-23
Original file line numberDiff line numberDiff line change
@@ -26,18 +26,111 @@ end
2626
# Split
2727
#
2828

29+
"""
    groupsort_indexer(x::AbstractVector, ngroups::Integer, null_last::Bool=false)

Counting sort of the positions of `x` by group code. Translated from Wes McKinney's
`groupsort_indexer` in pandas (file: src/groupby.pyx).

`x` holds integer codes in `0:ngroups`; code `0` is the slot reserved for NULL.
When `null_last` is `true`, the rows with code `0` are placed after all other groups
instead of before them.

Returns a tuple `(result, starts, counts)`:

* `result`: indices of `x` ordered by group, keeping the original order within a group
* `starts`: the write cursors after the fill, i.e. entry `g + 1` is one past the last
  position used by the rows with code `g`
* `counts`: entry `g + 1` is the number of rows with code `g`
"""
function groupsort_indexer(x::AbstractVector, ngroups::Integer, null_last::Bool=false)
    n = length(x)

    # count group sizes, location 0 for NULL
    counts = fill(0, ngroups + 1)
    for code in x
        counts[code + 1] += 1
    end

    # mark the start of each contiguous group of like-indexed data
    starts = fill(1, ngroups + 1)
    # when nulls go last, the first non-null group (slot 2) starts at position 1
    for g = (null_last ? 3 : 2):(ngroups + 1)
        starts[g] = starts[g - 1] + counts[g - 1]
    end
    if null_last
        # nulls fill the tail of the result; computing the start from `n` stays
        # correct when ngroups == 0 (slot 1 is then also the last slot)
        starts[1] = n - counts[1] + 1
    end

    # this is our indexer
    result = fill(0, n)
    for i in eachindex(x)
        label = x[i] + 1
        result[starts[label]] = i
        starts[label] += 1
    end
    return result, starts, counts
end
62+
63+
# Assign an integer code to each distinct value of `col`, and combine these codes with
# the existing code vector `x` (modified in place): `code * ngroups` is added to `x[i]`.
# Codes are zero-based. With `sort = true` they follow the sorted order of the values,
# otherwise the order of first appearance in `col`.
# Returns the number of distinct values found in `col`.
function combine_col!{T}(x::AbstractVector, col::AbstractVector{T},
                         ngroups::Integer, sort::Bool)
    # the loops below run under @inbounds, so mismatched lengths must be caught here
    length(x) == length(col) ||
        throw(DimensionMismatch("x and col must have the same length"))

    d = Dict{T, UInt32}()
    y = Vector{UInt32}(length(x))
    n = 0
    # Note: using get! instead of ht_keyindex triggers lots of allocations
    @inbounds for i in eachindex(x)
        v = col[i]
        index = Base.ht_keyindex(d, v)
        if index < 0 # new level
            y[i] = d[v] = n
            n += 1
        else
            y[i] = d.vals[index]
        end
    end

    if sort
        # compute mapping from unsorted to sorted codes
        # tmp[k]: position (in Dict iteration order) of the k-th smallest key
        tmp = sortperm(collect(keys(d)))
        # perm[j]: zero-based sorted rank of the key at Dict position j
        perm = ipermute!(collect(0:(n-1)), tmp)
        # tmp[c + 1]: Dict position of the key which was given unsorted code c
        sortperm!(tmp, collect(values(d)))
        # perm[c + 1]: zero-based sorted rank of the key with unsorted code c
        permute!(perm, tmp)

        @inbounds for i in eachindex(x)
            x[i] += perm[y[i] + 1] * ngroups
        end
    else
        @inbounds for i in eachindex(x)
            x[i] += y[i] * ngroups
        end
    end

    n
end
99+
100+
# Method for categorical columns: reuses the integer references stored in the array
# instead of hashing the values.
# Codes are derived from the pool's level order; the `sort` argument is not consulted.
# Returns the number of levels, plus one if a null reference (0) was encountered.
function combine_col!(x::AbstractVector,
                      col::Union{AbstractCategoricalVector, AbstractNullableCategoricalVector},
                      ngroups::Integer, sort::Bool)
    nlevels = length(levels(col))
    order = CategoricalArrays.order(col.pool)
    # codes[ref + 1] is the zero-based code used for reference `ref`
    codes = similar(order, length(order) + 1)
    # slot for ref 0 gets the highest code: sort nulls last, only used if present
    codes[1] = nlevels
    for k in eachindex(order)
        codes[k + 1] = order[k] - 1
    end
    # only nullable arrays are checked for null references
    nullable = eltype(col) <: Nullable
    sawnull = false
    @inbounds for i in eachindex(x)
        ref = col.refs[i]
        x[i] += codes[ref + 1] * ngroups
        sawnull |= nullable & (ref == 0)
    end
    nlevels + sawnull
end
120+
29121
"""
30122
A view of an AbstractDataTable split into row groups
31123
32124
```julia
33-
groupby(d::AbstractDataTable, cols)
34-
groupby(cols)
125+
groupby(d::AbstractDataTable, cols; sort = true)
126+
groupby(cols; sort = true)
35127
```
36128
37129
### Arguments
38130
39131
* `d` : an AbstractDataTable to split (optional, see [Returns](#returns))
40132
* `cols` : data table columns to group by
133+
* `sort`: whether to sort row groups; disable sorting for maximum performance
41134
42135
### Returns
43136
@@ -79,17 +172,24 @@ dt |> groupby([:a, :b]) |> [sum, length]
79172
```
80173
81174
"""
82-
function groupby{T}(dt::AbstractDataTable, cols::Vector{T}; sort::Bool = false)
83-
sdt = dt[cols]
84-
dt_groups = group_rows(sdt)
85-
# sort the groups
86-
if sort
87-
group_perm = sortperm(view(sdt, dt_groups.rperm[dt_groups.starts]))
88-
permute!(dt_groups.starts, group_perm)
89-
Base.permute!!(dt_groups.stops, group_perm)
175+
function groupby{T}(d::AbstractDataTable, cols::Vector{T}; sort::Bool = true)
    ## a subset of Wes McKinney's algorithm here:
    ##     http://wesmckinney.com/blog/?p=489

    # one combined code per row: starts at 1, then each key column adds its share
    rowcodes = ones(UInt32, nrow(d))
    ngroups = 1
    # key columns are processed from the last one to the first one
    for j in reverse(1:length(cols))
        # also compute the number of groups, which is the product of the set lengths
        ngroups *= combine_col!(rowcodes, d[cols[j]], ngroups, sort)
        # TODO if ngroups is really big, shrink it
    end
    rperm, bounds = groupsort_indexer(rowcodes, ngroups)
    # Remove zero-length groupings
    bounds = _groupedunique!(bounds)
    # each group stops right before the next boundary
    stops = bounds[2:end]
    stops .-= 1
    # the last boundary is not the start of any group
    pop!(bounds)
    return GroupedDataTable(d, cols, rperm, bounds, stops)
end
94194
# Convenience method for a single grouping column: wraps `cols` in a one-element
# vector. `sort` defaults to `true`, matching the documented `groupby` signature.
groupby(d::AbstractDataTable, cols; sort::Bool = true) = groupby(d, [cols], sort = sort)
95195

@@ -263,8 +363,8 @@ Split-apply-combine in one step; apply `f` to each grouping in `d`
263363
based on columns `col`
264364
265365
```julia
266-
by(d::AbstractDataTable, cols, f::Function; sort::Bool = false)
267-
by(f::Function, d::AbstractDataTable, cols; sort::Bool = false)
366+
by(d::AbstractDataTable, cols, f::Function; sort::Bool = true)
367+
by(f::Function, d::AbstractDataTable, cols; sort::Bool = true)
268368
```
269369
270370
### Arguments
@@ -273,7 +373,7 @@ by(f::Function, d::AbstractDataTable, cols; sort::Bool = false)
273373
* `cols` : a column indicator (Symbol, Int, Vector{Symbol}, etc.)
274374
* `f` : a function to be applied to groups; expects each argument to
275375
be an AbstractDataTable
276-
* `sort`: sort row groups (no sorting by default)
376+
* `sort`: whether to sort row groups; disable sorting for maximum performance
277377
278378
`f` can return a value, a vector, or a DataTable. For a value or
279379
vector, these are merged into a column along with the `cols` keys. For
@@ -321,8 +421,8 @@ Split-apply-combine that applies a set of functions over columns of an
321421
AbstractDataTable or GroupedDataTable
322422
323423
```julia
324-
aggregate(d::AbstractDataTable, cols, fs)
325-
aggregate(gd::GroupedDataTable, fs)
424+
aggregate(d::AbstractDataTable, cols, fs; sort::Bool=true)
425+
aggregate(gd::GroupedDataTable, fs; sort::Bool=true)
326426
```
327427
328428
### Arguments
@@ -332,6 +432,7 @@ aggregate(gd::GroupedDataTable, fs)
332432
* `cols` : a column indicator (Symbol, Int, Vector{Symbol}, etc.)
333433
* `fs` : a function or vector of functions to be applied to vectors
334434
within groups; expects each argument to be a column vector
435+
* `sort`: whether to sort row groups; disable sorting for maximum performance
335436
336437
Each `fs` should return a value or vector. All returns must be the
337438
same length.
@@ -353,15 +454,17 @@ dt |> groupby(:a) |> [sum, x->mean(dropnull(x))] # equivalent
353454
```
354455
355456
"""
356-
aggregate(d::AbstractDataTable, fs::Function; sort::Bool=false) = aggregate(d, [fs], sort=sort)
357-
function aggregate{T<:Function}(d::AbstractDataTable, fs::Vector{T}; sort::Bool=false)
457+
function aggregate(d::AbstractDataTable, fs::Function; sort::Bool=true)
    # single function: wrap it in a vector and defer to the vector method
    return aggregate(d, [fs], sort=sort)
end
459+
function aggregate{T<:Function}(d::AbstractDataTable, fs::Vector{T}; sort::Bool=true)
    # headers are built from the function names and the column names of `d`
    return _aggregate(d, fs, _makeheaders(fs, _names(d)), sort)
end
361463

362464
# Applies aggregate to non-key cols of each SubDataTable of a GroupedDataTable
363-
aggregate(gd::GroupedDataTable, f::Function; sort::Bool=false) = aggregate(gd, [f], sort=sort)
364-
function aggregate{T<:Function}(gd::GroupedDataTable, fs::Vector{T}; sort::Bool=false)
465+
function aggregate(gd::GroupedDataTable, f::Function; sort::Bool=true)
    # single function: wrap it in a vector and defer to the vector method
    return aggregate(gd, [f], sort=sort)
end
467+
function aggregate{T<:Function}(gd::GroupedDataTable, fs::Vector{T}; sort::Bool=true)
365468
headers = _makeheaders(fs, setdiff(_names(gd), gd.cols))
366469
res = combine(map(x -> _aggregate(without(x, gd.cols), fs, headers), gd))
367470
sort && sort!(res, cols=headers)
@@ -375,7 +478,7 @@ end
375478
function aggregate{S<:ColumnIndex, T <:Function}(d::AbstractDataTable,
                                                 cols::Union{S, AbstractVector{S}},
                                                 fs::Union{T, Vector{T}};
                                                 sort::Bool=true)
    # Split `d` by `cols`, then aggregate each group with `fs`.
    # `sort` is forwarded to both steps: the GroupedDataTable method has its own
    # `sort` keyword defaulting to `true`, so without forwarding `sort=false`
    # would still sort the result.
    aggregate(groupby(d, cols, sort=sort), fs, sort=sort)
end
381484

@@ -384,7 +487,8 @@ function _makeheaders{T<:Function}(fs::Vector{T}, cn::Vector{Symbol})
384487
[Symbol(colname,'_',fname) for fname in fnames for colname in cn]
385488
end
386489

387-
function _aggregate{T<:Function}(d::AbstractDataTable, fs::Vector{T}, headers::Vector{Symbol}, sort::Bool=false)
490+
function _aggregate{T<:Function}(d::AbstractDataTable, fs::Vector{T},
491+
headers::Vector{Symbol}, sort::Bool=true)
388492
res = DataTable(Any[vcat(f(d[i])) for f in fs for i in 1:size(d, 2)], headers)
389493
sort && sort!(res, cols=headers)
390494
res

src/other/utils.jl

+20
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,26 @@ function countnull(a::CategoricalArray)
155155
return res
156156
end
157157

158+
# Remove duplicates from a vector in which equal elements are stored contiguously.
if isdefined(Base, :unique!)
    # unique!() includes a fast path for sorted vectors
    _groupedunique!(A::AbstractVector) = unique!(A)
else # Julia < 0.7
    function _groupedunique!(A::AbstractVector)
        isempty(A) && return A
        idxs = eachindex(A)
        # index of the last element kept so far, and its value
        kept = first(idxs)
        prev = A[kept]
        for x in A
            if !isequal(x, prev)
                # start of a new run: store its value right after the kept elements
                kept += 1
                prev = A[kept] = x
            end
        end
        resize!(A, kept - first(idxs) + 1)
    end
end
177+
158178
# Gets the name of a function. Used in groupedatatable/grouping.jl
159179
function _fnames{T<:Function}(fs::Vector{T})
160180
λcounter = 0

test/grouping.jl

+29-6
Original file line numberDiff line numberDiff line change
@@ -165,11 +165,34 @@ module TestGrouping
165165
levels!(dt[:Key1], ["Z", "B", "A"])
166166
levels!(dt[:Key2], ["Z", "B", "A"])
167167
gd = groupby(dt, :Key1)
168-
@test isequal(gd[1], DataTable(Key1=["A", "A"], Key2=["A", "B"], Value=1:2))
169-
@test isequal(gd[2], DataTable(Key1=["B", "B"], Key2=["A", "B"], Value=3:4))
168+
@test isequal(gd[1], DataTable(Key1=["B", "B"], Key2=["A", "B"], Value=3:4))
169+
@test isequal(gd[2], DataTable(Key1=["A", "A"], Key2=["A", "B"], Value=1:2))
170170
gd = groupby(dt, [:Key1, :Key2])
171-
@test isequal(gd[1], DataTable(Key1="A", Key2="A", Value=1))
172-
@test isequal(gd[2], DataTable(Key1="A", Key2="B", Value=2))
173-
@test isequal(gd[3], DataTable(Key1="B", Key2="A", Value=3))
174-
@test isequal(gd[4], DataTable(Key1="B", Key2="B", Value=4))
171+
@test isequal(gd[1], DataTable(Key1="B", Key2="B", Value=4))
172+
@test isequal(gd[2], DataTable(Key1="B", Key2="A", Value=3))
173+
@test isequal(gd[3], DataTable(Key1="A", Key2="B", Value=2))
174+
@test isequal(gd[4], DataTable(Key1="A", Key2="A", Value=1))
175+
176+
# test NullableArray and NullableCategoricalArray with nulls
177+
for (S, T) in ((NullableArray, NullableArray),
178+
(NullableCategoricalArray, NullableCategoricalArray),
179+
(NullableArray, NullableCategoricalArray),
180+
(NullableCategoricalArray, NullableArray))
181+
dt = DataTable(Key1 = S(["A", "A", "B", Nullable(), Nullable()]),
182+
Key2 = T(["A", "B", "A", Nullable(), "A"]),
183+
Value = 1:5)
184+
gd = groupby(dt, :Key1)
185+
@test isequal(gd[1], DataTable(Key1=Nullable{String}["A", "A"],
186+
Key2=Nullable{String}["A", "B"], Value=1:2))
187+
@test isequal(gd[2], DataTable(Key1=Nullable{String}["B"],
188+
Key2=Nullable{String}["A"], Value=3))
189+
@test isequal(gd[3], DataTable(Key1=[Nullable(), Nullable()],
190+
Key2=Nullable{String}[Nullable(), "A"], Value=4:5))
191+
gd = groupby(dt, [:Key1, :Key2])
192+
@test isequal(gd[1], DataTable(Key1=Nullable("A"), Key2=Nullable("A"), Value=1))
193+
@test isequal(gd[2], DataTable(Key1=Nullable("A"), Key2=Nullable("B"), Value=2))
194+
@test isequal(gd[3], DataTable(Key1=Nullable("B"), Key2=Nullable("A"), Value=3))
195+
@test isequal(gd[4], DataTable(Key1=Nullable(), Key2=Nullable("A"), Value=5))
196+
@test isequal(gd[5], DataTable(Key1=Nullable(), Key2=Nullable(), Value=4))
197+
end
175198
end

0 commit comments

Comments
 (0)