This repository was archived by the owner on May 5, 2019. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 11
/
Copy pathutils.jl
190 lines (171 loc) · 6.15 KB
/
utils.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
# Rows grouping.
# Maps row contents to the indices of all the equal rows.
# Used by groupby(), join(), nonunique()
#
# Instances are assembled by group_rows(); the hash-table part (rhashes, gslots)
# comes from row_group_slots(). Lookups by row content go through findrow(),
# which starts at the slot derived from the row hash and, on a miss, tries the
# following slots (wrapping around) until it hits an empty slot or a match.
immutable RowGroupDict{T<:AbstractDataTable}
dt::T # source data table whose rows were grouped
ngroups::Int # number of groups (distinct rows)
rhashes::Vector{UInt} # row hashes, one per row of dt
gslots::Vector{Int} # hashindex -> index of group-representative row, 0 marks an unoccupied slot
groups::Vector{Int} # group index for each row
rperm::Vector{Int} # permutation of row indices that sorts them by groups
starts::Vector{Int} # starts of ranges in rperm for each group
stops::Vector{Int} # stops (inclusive) of ranges in rperm for each group
end
# "kernel" functions for hashrows()
# Fold one column into the running row hashes:
# for every row `i`, `h[i]` is replaced by `hash(v[i], h[i])`.
# `h` is modified in place and also returned.
# `v` is expected to hold at least `length(h)` elements (access is unchecked).
function hashrows_col!(h::Vector{UInt}, v::AbstractVector)
    for (i, seed) in enumerate(h)
        # `seed` is the hash accumulated from the previously processed columns
        @inbounds h[i] = hash(v[i], seed)
    end
    return h
end
# Variant for columns of `Nullable` elements.
# A null element contributes `Base.nullablehash_seed` to the row hash,
# a non-null element contributes its unwrapped value.
# `h` is modified in place and also returned.
function hashrows_col!{T<:Nullable}(h::Vector{UInt}, v::AbstractVector{T})
    @inbounds for i in eachindex(h)
        el = v[i]
        if isnull(el)
            h[i] = hash(Base.nullablehash_seed, h[i])
        else
            h[i] = hash(unsafe_get(el), h[i])
        end
    end
    return h
end
# Variant for categorical columns.
# Each row hash is updated with the pool entry that the row's reference code
# points to, so the result should be the same hash as for AbstractVector{T}.
# `h` is modified in place and also returned.
function hashrows_col!{T}(h::Vector{UInt}, v::AbstractCategoricalVector{T})
    # TODO is it possible to optimize by hashing the pool values once?
    poolindex = CategoricalArrays.index(v.pool)
    @inbounds for (i, ref) in enumerate(v.refs)
        h[i] = hash(poolindex[ref], h[i])
    end
    return h
end
# Variant for nullable categorical columns.
# A reference code of 0 is hashed as `Base.nullablehash_seed`; any other code is
# hashed via the pool entry it points to, so the result should be the same hash
# as for AbstractNullableVector{T}.
# `h` is modified in place and also returned.
function hashrows_col!{T}(h::Vector{UInt}, v::AbstractNullableCategoricalVector{T})
    # TODO is it possible to optimize by hashing the pool values once?
    @inbounds for (i, ref) in enumerate(v.refs)
        if ref == 0
            h[i] = hash(Base.nullablehash_seed, h[i])
        else
            h[i] = hash(CategoricalArrays.index(v.pool)[ref], h[i])
        end
    end
    return h
end
# Compute one hash per row of `dt`.
# Works column by column: starting from all-zero hashes, every column is folded
# into the per-row hashes with hashrows_col!().
# Returns a Vector{UInt} of length nrow(dt).
function hashrows(dt::AbstractDataTable)
    rowhashes = zeros(UInt, nrow(dt))
    foreach(columns(dt)) do col
        hashrows_col!(rowhashes, col)
    end
    return rowhashes
end
# Helper function for RowGroupDict.
# Hashes all rows of `dt` and inserts them into a hash map whose slots store
# row indices; collisions are resolved by moving on to the next slot
# (wrapping around at the end of the slot array).
# Returns a tuple:
# 1) the number of row groups in a data table
# 2) vector of row hashes
# 3) slot array for a hash map, non-zero values are
#    the indices of the first row in a group
# If the optional `groups` vector is given (it must have nrow(dt) elements),
# it is filled with the group index of each row.
function row_group_slots(dt::AbstractDataTable,
                         groups::Union{Vector{Int}, Void} = nothing)
    @assert groups === nothing || length(groups) == nrow(dt)
    rhashes = hashrows(dt)
    sz = Base._tablesz(length(rhashes))
    @assert sz >= length(rhashes)
    szm1 = sz-1 # bit mask that maps a hash onto a 0-based slot number
    gslots = zeros(Int, sz)
    ngroups = 0
    @inbounds for row in eachindex(rhashes)
        rh = rhashes[row]
        slotix = rh & szm1 + 1 # slot where probing starts for this row
        group_ix = 0
        probe = 0
        while true
            rep_row = gslots[slotix]
            if rep_row == 0
                # unoccupied slot: the current row opens a new group
                gslots[slotix] = row
                ngroups += 1
                group_ix = ngroups
                break
            end
            if rh == rhashes[rep_row]
                # equal hashes: compare the two rows column by column
                same = true
                for col in columns(dt)
                    if !isequal_colel(col, row, rep_row)
                        same = false # hash collision, rows differ
                        break
                    end
                end
                if same
                    # the row joins the group of its representative;
                    # without a groups vector the index stays 0
                    if groups !== nothing
                        group_ix = groups[rep_row]
                    end
                    break
                end
            end
            slotix = slotix & szm1 + 1 # miss, move on to the next slot
            probe += 1
            @assert probe < sz
        end
        if groups !== nothing
            groups[row] = group_ix
        end
    end
    return ngroups, rhashes, gslots
end
# Builds RowGroupDict for a given datatable.
# Partly uses the code of Wes McKinney's groupsort_indexer in pandas (file: src/groupby.pyx).
#
# Steps: assign a group index to every row (row_group_slots), count the rows of
# each group, turn the counts into start offsets, and finally scatter the row
# indices into `rperm` so that the rows of each group are contiguous.
function group_rows(dt::AbstractDataTable)
    groups = Vector{Int}(nrow(dt))
    ngroups, rhashes, gslots = row_group_slots(dt, groups)

    # `stops` first serves as the per-group row counter
    stops = zeros(Int, ngroups)
    @inbounds for row in eachindex(groups)
        stops[groups[row]] += 1
    end

    # position in `rperm` at which each group begins
    starts = Vector{Int}(ngroups)
    if ngroups > 0
        starts[1] = 1
        @inbounds for g in 2:ngroups
            starts[g] = starts[g-1] + stops[g-1]
        end
    end

    # scatter the rows; `stops` is now the next free position of each group
    rperm = Vector{Int}(length(groups))
    copy!(stops, starts)
    @inbounds for row in eachindex(groups)
        g = groups[row]
        rperm[stops[g]] = row
        stops[g] += 1
    end

    # each cursor ended one past the last row of its group; step back so that
    # `stops` holds inclusive end positions
    @inbounds for g in eachindex(stops)
        stops[g] -= 1
    end

    return RowGroupDict(dt, ngroups, rhashes, gslots, groups, rperm, starts, stops)
end
# Number of unique row groups stored in the dictionary.
function ngroups(gd::RowGroupDict)
    return gd.ngroups
end
# Find the index of a row in `gd` whose content matches row `row` of `dt`.
# Returns 0 if there is no such row.
# When `dt` is the very table `gd` was built from, `row` itself is returned
# without any lookup.
function findrow(gd::RowGroupDict, dt::DataTable, row::Int)
    if gd.dt === dt
        return row # same table, the row matches itself
    end
    # different tables, the row has to be matched by content
    rhash = rowhash(dt, row)
    szm1 = length(gd.gslots)-1
    first_slot = rhash & szm1 + 1
    slotix = first_slot
    while true
        candidate = gd.gslots[slotix]
        if candidate == 0
            return 0 # reached an unoccupied slot: not found
        end
        if rhash == gd.rhashes[candidate] && isequal_row(gd.dt, candidate, dt, row)
            return candidate # found
        end
        slotix = (slotix & szm1) + 1 # miss, try the next slot
        if slotix == first_slot
            return 0 # every slot has been visited: not found
        end
    end
end
# Find the indices of the rows in `gd` whose content matches row `row` of `dt`.
# Returns a view into `gd.rperm` covering the matching group,
# or an empty view if no row matches.
function Base.get(gd::RowGroupDict, dt::DataTable, row::Int)
    rep_row = findrow(gd, dt, row)
    if rep_row == 0
        positions = 0:-1 # empty range, yields an empty view
    else
        g = gd.groups[rep_row]
        positions = gd.starts[g]:gd.stops[g]
    end
    return view(gd.rperm, positions)
end
# Indices of the rows in `gd` whose content matches the given table row,
# as a view into `gd.rperm`.
# Throws KeyError(dtr) if no row matches.
function Base.getindex(gd::RowGroupDict, dtr::DataTableRow)
    rep_row = findrow(gd, dtr.dt, dtr.row)
    if rep_row == 0
        throw(KeyError(dtr))
    end
    g = gd.groups[rep_row]
    first_pos = gd.starts[g]
    last_pos = gd.stops[g]
    return view(gd.rperm, first_pos:last_pos)
end
# Check whether `gd` contains a row matching row `row` of `dt`
# (i.e. findrow() does not report 0).
function Base.in(gd::RowGroupDict, dt::DataTable, row::Int)
    return findrow(gd, dt, row) != 0
end