Skip to content
This repository was archived by the owner on May 5, 2019. It is now read-only.

Commit f5007b5

Browse files
committed
Specialize row_group_slots() and findrow() on column types to improve performance
Looping over columns is very slow when their type is unknown at compile time. Specialize the method on the types of the key (grouping) columns by passing a tuple of columns rather than a DataTable. This will force compiling a specific method for each combination of key types, but their number should remain relatively low and the one-time cost is worth it. This dramatically improves performance of groupby(), but does not have a large effect on join() since it is very inefficient in other areas. Also add return type assertion for rowhash(). The fact that the type of the columns isn't known at compile time appears to confuse inference, which isn't able to detect that this function always returns UInt. This reduces a lot the number of allocations when calling join(), but doesn't really change performance.
1 parent 3875020 commit f5007b5

File tree

4 files changed

+54
-29
lines changed

4 files changed

+54
-29
lines changed

src/abstractdatatable/join.jl

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -141,9 +141,11 @@ function update_row_maps!(left_table::AbstractDataTable,
141141
@inline update!(mask::Vector{Bool}, orig_ixs::AbstractArray) = (mask[orig_ixs] = false)
142142

143143
# iterate over left rows and compose the left<->right index map
144+
right_dict_cols = ntuple(i -> right_dict.dt[i], ncol(right_dict.dt))
145+
left_table_cols = ntuple(i -> left_table[i], ncol(left_table))
144146
next_join_ix = 1
145147
for l_ix in 1:nrow(left_table)
146-
r_ixs = findrows(right_dict, left_table, l_ix)
148+
r_ixs = findrows(right_dict, left_table, right_dict_cols, left_table_cols, l_ix)
147149
if isempty(r_ixs)
148150
update!(leftonly_ixs, l_ix, next_join_ix)
149151
next_join_ix += 1
@@ -284,8 +286,10 @@ function Base.join(dt1::AbstractDataTable,
284286
# iterate over left rows and leave those found in right
285287
left_ixs = Vector{Int}()
286288
sizehint!(left_ixs, nrow(joiner.dtl))
289+
dtr_on_grp_cols = ntuple(i -> dtr_on_grp.dt[i], ncol(dtr_on_grp.dt))
290+
dtl_on_cols = ntuple(i -> joiner.dtl_on[i], ncol(joiner.dtl_on))
287291
@inbounds for l_ix in 1:nrow(joiner.dtl_on)
288-
if findrow(dtr_on_grp, joiner.dtl_on, l_ix) != 0
292+
if findrow(dtr_on_grp, joiner.dtl_on, dtr_on_grp_cols, dtl_on_cols, l_ix) != 0
289293
push!(left_ixs, l_ix)
290294
end
291295
end
@@ -296,8 +300,10 @@ function Base.join(dt1::AbstractDataTable,
296300
# iterate over left rows and leave those not found in right
297301
leftonly_ixs = Vector{Int}()
298302
sizehint!(leftonly_ixs, nrow(joiner.dtl))
303+
dtr_on_grp_cols = ntuple(i -> dtr_on_grp.dt[i], ncol(dtr_on_grp.dt))
304+
dtl_on_cols = ntuple(i -> joiner.dtl_on[i], ncol(joiner.dtl_on))
299305
@inbounds for l_ix in 1:nrow(joiner.dtl_on)
300-
if findrow(dtr_on_grp, joiner.dtl_on, l_ix) == 0
306+
if findrow(dtr_on_grp, joiner.dtl_on, dtr_on_grp_cols, dtl_on_cols, l_ix) == 0
301307
push!(leftonly_ixs, l_ix)
302308
end
303309
end

src/datatablerow/datatablerow.jl

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -52,14 +52,16 @@ end
5252

5353
# hash of DataTable rows based on its values
5454
# so that duplicate rows would have the same hash
55-
function rowhash(dt::DataTable, r::Int, h::UInt = zero(UInt))
56-
@inbounds for col in columns(dt)
57-
h = hash_colel(col, r, h)
58-
end
59-
return h
55+
# table columns are passed as a tuple of vectors to ensure type specialization
56+
rowhash(cols::Tuple{AbstractVector}, r::Int, h::UInt = zero(UInt))::UInt =
57+
hash_colel(cols[1], r, h)
58+
function rowhash(cols::Tuple{Vararg{AbstractVector}}, r::Int, h::UInt = zero(UInt))::UInt
59+
h = hash_colel(cols[1], r, h)
60+
rowhash(Base.tail(cols), r, h)
6061
end
6162

62-
Base.hash(r::DataTableRow, h::UInt = zero(UInt)) = rowhash(r.dt, r.row, h)
63+
Base.hash(r::DataTableRow, h::UInt = zero(UInt)) =
64+
rowhash(ntuple(i -> r.dt[i], ncol(r.dt)), r.row, h)
6365

6466
# comparison of DataTable rows
6567
# only the rows of the same DataTable could be compared
@@ -81,6 +83,19 @@ isequal_colel(a::Nullable, b::Any) = !isnull(a) & isequal(unsafe_get(a), b)
8183
isequal_colel(a::Any, b::Nullable) = isequal_colel(b, a)
8284
isequal_colel(a::Nullable, b::Nullable) = isequal(a, b)
8385

86+
# table columns are passed as a tuple of vectors to ensure type specialization
87+
isequal_row(cols::Tuple{AbstractVector}, r1::Int, r2::Int) =
88+
isequal_colel(cols[1][r1], cols[1][r2])
89+
isequal_row(cols::Tuple{Vararg{AbstractVector}}, r1::Int, r2::Int) =
90+
isequal_colel(cols[1][r1], cols[1][r2]) && isequal_row(Base.tail(cols), r1, r2)
91+
92+
isequal_row(cols1::Tuple{AbstractVector}, r1::Int, cols2::Tuple{AbstractVector}, r2::Int) =
93+
isequal_colel(cols1[1][r1], cols2[1][r2])
94+
isequal_row(cols1::Tuple{Vararg{AbstractVector}}, r1::Int,
95+
cols2::Tuple{Vararg{AbstractVector}}, r2::Int) =
96+
isequal_colel(cols1[1][r1], cols2[1][r2]) &&
97+
isequal_row(Base.tail(cols1), r1, Base.tail(cols2), r2)
98+
8499
function isequal_row(dt1::AbstractDataTable, r1::Int, dt2::AbstractDataTable, r2::Int)
85100
if dt1 === dt2
86101
if r1 == r2

src/datatablerow/utils.jl

Lines changed: 23 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -80,10 +80,13 @@ end
8080
# 3) slot array for a hash map, non-zero values are
8181
# the indices of the first row in a group
8282
# Optional group vector is set to the group indices of each row
83-
function row_group_slots(dt::AbstractDataTable,
83+
row_group_slots(dt::AbstractDataTable, groups::Union{Vector{Int}, Void} = nothing) =
84+
row_group_slots(ntuple(i -> dt[i], ncol(dt)), hashrows(dt), groups)
85+
86+
function row_group_slots(cols::Tuple{Vararg{AbstractVector}},
87+
rhashes::AbstractVector{UInt},
8488
groups::Union{Vector{Int}, Void} = nothing)
85-
@assert groups === nothing || length(groups) == nrow(dt)
86-
rhashes = hashrows(dt)
89+
@assert groups === nothing || length(groups) == length(cols[1])
8790
# inspired by Dict code from base cf. https://github.com/JuliaData/DataTables.jl/pull/17#discussion_r102481481
8891
sz = Base._tablesz(length(rhashes))
8992
@assert sz >= length(rhashes)
@@ -102,17 +105,10 @@ function row_group_slots(dt::AbstractDataTable,
102105
gix = ngroups += 1
103106
break
104107
elseif rhashes[i] == rhashes[g_row] # occupied slot, check if miss or hit
105-
eq = true
106-
for col in columns(dt)
107-
if !isequal_colel(col, i, g_row)
108-
eq = false # miss
109-
break
110-
end
111-
end
112-
if eq # hit
108+
if isequal_row(cols, i, g_row) # hit
113109
gix = groups !== nothing ? groups[g_row] : 0
114-
break
115110
end
111+
break
116112
end
117113
slotix = slotix & szm1 + 1 # check the next slot
118114
probe += 1
@@ -158,17 +154,21 @@ function group_rows(dt::AbstractDataTable)
158154
end
159155

160156
# Find index of a row in gd that matches given row by content, 0 if not found
161-
function findrow(gd::RowGroupDict, dt::DataTable, row::Int)
157+
function findrow(gd::RowGroupDict,
158+
dt::DataTable,
159+
gd_cols::Tuple{Vararg{AbstractVector}},
160+
dt_cols::Tuple{Vararg{AbstractVector}},
161+
row::Int)
162162
(gd.dt === dt) && return row # same table, return itself
163163
# different tables, content matching required
164-
rhash = rowhash(dt, row)
164+
rhash = rowhash(dt_cols, row)
165165
szm1 = length(gd.gslots)-1
166166
slotix = ini_slotix = rhash & szm1 + 1
167167
while true
168168
g_row = gd.gslots[slotix]
169169
if g_row == 0 || # not found
170170
(rhash == gd.rhashes[g_row] &&
171-
isequal_row(gd.dt, g_row, dt, row)) # found
171+
isequal_row(gd_cols, g_row, dt_cols, row)) # found
172172
return g_row
173173
end
174174
slotix = (slotix & szm1) + 1 # miss, try the next slot
@@ -179,15 +179,20 @@ end
179179

180180
# Find indices of rows in 'gd' that match given row by content.
181181
# return empty set if no row matches
182-
function findrows(gd::RowGroupDict, dt::DataTable, row::Int)
183-
g_row = findrow(gd, dt, row)
182+
function findrows(gd::RowGroupDict,
183+
dt::DataTable,
184+
gd_cols::Tuple{Vararg{AbstractVector}},
185+
dt_cols::Tuple{Vararg{AbstractVector}},
186+
row::Int)
187+
g_row = findrow(gd, dt, gd_cols, dt_cols, row)
184188
(g_row == 0) && return view(gd.rperm, 0:-1)
185189
gix = gd.groups[g_row]
186190
return view(gd.rperm, gd.starts[gix]:gd.stops[gix])
187191
end
188192

189193
function Base.getindex(gd::RowGroupDict, dtr::DataTableRow)
190-
g_row = findrow(gd, dtr.dt, dtr.row)
194+
g_row = findrow(gd, dtr.dt, ntuple(i -> gd.dt[i], ncol(gd.dt)),
195+
ntuple(i -> dtr.dt[i], ncol(dtr.dt)), dtr.row)
191196
(g_row == 0) && throw(KeyError(dtr))
192197
gix = gd.groups[g_row]
193198
return view(gd.rperm, gd.starts[gix]:gd.stops[gix])

test/datatablerow.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,7 @@ module TestDataTableRow
6767
# getting groups for the rows of the other frames
6868
@test length(gd[DataTableRow(dt6, 1)]) > 0
6969
@test_throws KeyError gd[DataTableRow(dt6, 2)]
70-
@test isempty(DataTables.findrows(gd, dt6, 2))
71-
@test length(DataTables.findrows(gd, dt6, 2)) == 0
70+
@test isempty(DataTables.findrows(gd, dt6, (gd.dt[1],), (dt6[1],), 2))
7271

7372
# grouping empty frame
7473
gd = DataTables.group_rows(DataTable(x=Int[]))

0 commit comments

Comments
 (0)