Skip to content

Commit eac2650

Browse files
author
Pietro Vertechi
authored
implement efficient sortperm (JuliaArrays#31)
1 parent b120387 commit eac2650

File tree

3 files changed

+173
-7
lines changed

3 files changed

+173
-7
lines changed

src/StructArrays.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,10 @@ include("sort.jl")
1212
# Module initialisation hook. Wires up the optional integrations through
# `Requires.@require`: Tables includes "tables.jl", while PooledArrays and
# WeakRefStrings each add a method returning `true` for their own array type.
# NOTE(review): presumably each `@require` body only runs once the named package
# is loaded (Requires.jl semantics) — confirm against the Requires documentation.
function __init__()
1313
Requires.@require Tables="bd369af6-aec1-5ad0-b16a-f7cc5008161c" include("tables.jl")
1414
Requires.@require PooledArrays="2dfb63ee-cc39-5dd5-95bd-886bf059d720" begin
15-
isdiscrete(::PooledArrays.PooledArray) = true
15+
# Flag `PooledArray` columns so `ispooledarray` answers `true` for them.
ispooledarray(::PooledArrays.PooledArray) = true
1616
end
1717
Requires.@require WeakRefStrings="ea10d353-3f73-51f8-a26c-33c1cb351aa5" begin
18-
isdiscrete(::WeakRefStrings.StringArray) = true
18+
# Flag `StringArray` columns so `isstringarray` answers `true` for them.
isstringarray(::WeakRefStrings.StringArray) = true
1919
end
2020
end
2121

src/sort.jl

+127-2
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,137 @@
1-
isdiscrete(v) = false
1+
using Base.Sort, Base.Order
2+
3+
# Trait-style predicates consulted by `permute!` to pick a permutation strategy.
# These generic fallbacks answer `false` for any `AbstractArray`.
function isstringarray(::AbstractArray)
    return false
end

function ispooledarray(::AbstractArray)
    return false
end
25

36
# Reorder every field array of `c` according to the permutation `p` and return `c`.
# Nested `StructVector`s and columns flagged by `isstringarray` / `ispooledarray`
# are permuted through `permute!`; every other column is overwritten with the
# gathered copy `col[p]`.
function Base.permute!(c::StructVector, p::AbstractVector)
    foreachfield(c) do col
        use_permute = col isa StructVector || isstringarray(col) || ispooledarray(col)
        use_permute ? permute!(col, p) : copyto!(col, col[p])
    end
    return c
end
end
16+
17+
# Iterator over runs of equal ("tied") values of `vec`, visited in the order given
# by `perm`.
#   vec    - the data whose ties are inspected
#   perm   - positions into `vec`; ties are detected between consecutive entries
#   within - (first, last) positions of `perm` that the iteration is restricted to
struct TiedIndices{T <: AbstractVector}
    vec::T
    perm::Vector{Int}
    within::Tuple{Int, Int}
end

# Convenience constructor covering the whole of `vec`; `perm` defaults to
# `sortperm(vec)`. The bounds come from `firstindex`/`lastindex` rather than
# `extrema(axes(vec, 1))`, which throws on an empty vector. An empty `vec` now
# yields the empty window `(1, 0)`.
TiedIndices(vec::AbstractVector, perm=sortperm(vec)) =
    TiedIndices(vec, perm, (firstindex(vec), lastindex(vec)))
25+
26+
# No length is advertised for `TiedIndices`: the iterator reports an unknown size.
function Base.IteratorSize(::Type{<:TiedIndices})
    return Base.SizeUnknown()
end

# Items pair a value of the wrapped vector type with a `UnitRange{Int}`.
function Base.eltype(::Type{<:TiedIndices{T}}) where {T}
    return Pair{eltype(T), UnitRange{Int}}
end
30+
31+
# Produce the next run of tied values. The state `i` is the position in `perm`
# where the run starts (initially the lower bound of `within`). Returns
# `(value => positions, next_state)`, where `positions` is the range of `perm`
# positions sharing `value` under `isequal`, or `nothing` once `i` passes the
# upper bound of `within`.
function Base.iterate(n::TiedIndices, i = n.within[1])
    data, order = n.vec, n.perm
    upper = n.within[2]
    if i > upper
        return nothing
    end
    key = data[order[i]]
    stop = i
    # Advance `stop` to the first position whose value differs from `key`.
    while true
        stop > upper && break
        @inbounds isequal(key, data[order[stop]]) || break
        stop += 1
    end
    return (key => i:(stop - 1), stop)
end
42+
43+
# Functional alias for the `TiedIndices` constructor; arguments are forwarded as-is.
function tiedindices(args...)
    return TiedIndices(args...)
end
44+
45+
# Lazily pair each distinct run value with the original indices that hold it:
# the position ranges produced by `tiedindices` are mapped back through the
# permutation. Returns a generator of `value => indices` pairs.
function groupindices(args...)
    t = tiedindices(args...)
    perm = t.perm
    return ((row => perm[idxs]) for (row, idxs) in t)
end
50+
51+
# Compute a sorting permutation for a `StructVector` of tuples / named tuples.
# The first field array is sorted with `alg`; when more columns exist, ties are
# then broken column by column through `refine_perm!`.
function Base.sortperm(c::StructVector{T};
    alg = DEFAULT_UNSTABLE) where {T<:Union{Tuple, NamedTuple}}

    cols = fieldarrays(c)
    leading = cols[1]
    perm = sortperm(leading; alg = alg)
    # A single column needs no tie-breaking.
    length(cols) > 1 || return perm
    refine_perm!(perm, cols, 1, leading, cols[2], 1, length(leading))
    return perm
end
63+
64+
# In-place sort: permute `c` by its own sorting permutation.
function Base.sort!(c::StructArray{<:Union{Tuple, NamedTuple}})
    return permute!(c, sortperm(c))
end

# Out-of-place sort: index `c` with its sorting permutation.
function Base.sort(c::StructArray{<:Union{Tuple, NamedTuple}})
    return c[sortperm(c)]
end
66+
67+
# Methods from IndexedTables to refine sorting:
68+
# assuming x[p] is sorted, sort by remaining columns where x[p] is constant
69+
# Refine the permutation `p` in place within positions `lo:hi`.
#   p      - permutation being refined
#   cols   - all columns; `c` is the index of column `x` within `cols`
#   x      - column along which `p[lo:hi]` is scanned for ties
#   y      - next column, used to order the entries inside each tie of `x`
# Every run of tied `x` values longer than one entry is sorted by `y`, then
# refined recursively by the following column, if there is one.
function refine_perm!(p, cols, c, x, y, lo, hi)
    scratch = similar(p, 0)
    order = Base.Order.By(j->(@inbounds k=y[j]; k))
    ncols = length(cols)
    for (_, idxs) in TiedIndices(x, p, (lo, hi))
        run_lo, run_hi = extrema(idxs)
        # Runs of length one are already in their final position.
        run_hi > run_lo || continue
        sort_sub_by!(p, run_lo, run_hi, y, order, scratch)
        if c < ncols - 1
            refine_perm!(p, cols, c + 1, y, cols[c + 2], run_lo, run_hi)
        end
    end
end
84+
85+
# sort the values in v[i0:i1] in place, by array `by`
86+
# Generic fallback: merge-sort the slice `v[i0:i1]` in place under `order`.
# `temp` is cleared and then reused as the scratch buffer of the sort; `by` is
# not consulted here (the ordering already encodes it).
Base.@noinline function sort_sub_by!(v, i0, i1, by, order, temp)
    # `empty!` returns the buffer it just cleared, so it can be passed straight on.
    return sort!(v, i0, i1, MergeSort, order, empty!(temp))
end
90+
91+
# Integer specialisation: sort the slice `v[i0:i1]` in place by the keys `by[v[i]]`.
# When the keys span no more distinct values than there are entries, a counting
# sort (`sort_int_range_sub_by!`) is used; otherwise it falls back to a merge
# sort under `order`, with `temp` as scratch space. Returns `v`.
Base.@noinline function sort_sub_by!(v, i0, i1, by::Vector{T}, order, temp) where T<:Integer
    # Single pass to find the smallest and largest key in the slice.
    lo = hi = by[v[i0]]
    @inbounds for i = i0+1:i1
        val = by[v[i]]
        if val < lo
            lo = val
        elseif val > hi
            hi = val
        end
    end
    n = i1 - i0 + 1
    # `hi - lo` can wrap for fixed-width integers (e.g. Int8 keys -100 and 100, or
    # typemin/typemax). Since `hi >= lo`, a wrapped difference shows up as a
    # negative number, so `0 <= span` rules overflow out. `span < n` is the
    # overflow-free form of `span + 1 <= n`.
    span = hi - lo
    if 0 <= span < n
        sort_int_range_sub_by!(v, i0-1, n, by, Int(span) + 1, lo, temp)
    else
        empty!(temp)
        sort!(v, i0, i1, MergeSort, order, temp)
    end
    return v
end
111+
112+
# in-place counting sort of x[ioffs+1:ioffs+n] by values in `by`
113+
# Counting sort of `x[ioffs+1:ioffs+n]` in place, keyed by `by[x[i]]`.
#   rangelen - number of distinct key values the slice may contain
#   minval   - smallest key in the slice
#   temp     - scratch buffer, grown to at least `n` entries when shorter
# Entries with equal keys keep their relative order. Returns `x`.
function sort_int_range_sub_by!(x, ioffs, n, by, rangelen, minval, temp)
    shift = 1 - minval

    # counts[b + 1] first tallies the entries in bucket `b`; after the cumulative
    # sum, counts[b] is the first output slot of bucket `b`.
    counts = fill(0, rangelen+1)
    counts[1] = 1
    @inbounds for k = 1:n
        counts[by[x[k+ioffs]] + shift + 1] += 1
    end
    cumsum!(counts, counts)

    if length(temp) < n
        resize!(temp, n)
    end
    # Scatter each entry into its bucket's next free slot.
    @inbounds for k = 1:n
        item = x[k+ioffs]
        bucket = by[item] + shift
        slot = counts[bucket]
        temp[slot] = item
        counts[bucket] = slot + 1
    end

    # Copy the ordered entries back over the original slice.
    @inbounds for k = 1:n
        x[ioffs + k] = temp[k]
    end
    return x
end
137+

test/runtests.jl

+44-3
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,12 @@ end
2222
a = WeakRefStrings.StringVector(["a", "b", "c"])
2323
b = PooledArrays.PooledArray([1, 2, 3])
2424
c = [:a, :b, :c]
25-
@test StructArrays.isdiscrete(a)
26-
@test StructArrays.isdiscrete(b)
27-
@test !StructArrays.isdiscrete(c)
25+
@test !StructArrays.ispooledarray(a)
26+
@test StructArrays.isstringarray(a)
27+
@test StructArrays.ispooledarray(b)
28+
@test !StructArrays.isstringarray(b)
29+
@test !StructArrays.ispooledarray(c)
30+
@test !StructArrays.isstringarray(c)
2831
s = StructArray(a=a, b=b, c=c)
2932
permute!(s, [2, 3, 1])
3033
@test s.a == ["b", "c", "a"]
@@ -36,6 +39,44 @@ end
3639
@test s == t
3740
end
3841

42+
@testset "sortperm" begin
    # Same scenario twice: once with a plain `Vector` third column, once with a
    # `PooledArray` third column.
    for wrap in (identity, PooledArrays.PooledArray)
        work = StructArray(a=[1,1,2,2], b=[1,2,3,3], c=wrap(["a","b","c","d"]))
        sorted = StructArray(a=[1,1,2,2], b=[1,2,3,3], c=wrap(["a","b","c","d"]))
        @test issorted(work)
        @test sortperm(work) == [1,2,3,4]
        # Rotate the rows so the first row ends up last.
        permute!(work, [2,3,4,1])
        @test work == StructArray(a=[1,2,2,1], b=[2,3,3,1], c=wrap(["b","c","d","a"]))
        @test sortperm(work) == [4,1,2,3]
        @test !issorted(work)
        @test sort(work) == sorted
        sort!(work)
        @test work == sorted
    end
end
67+
68+
@testset "iterators" begin
    data = [1, 2, 3, 1, 1]
    ties = StructArrays.tiedindices(data)
    @test eltype(ties) == Pair{Int, UnitRange{Int}}
    # Runs are reported as ranges of positions in the sorting permutation.
    runs = collect(ties)
    @test map(first, runs) == [1, 2, 3]
    @test map(last, runs) == [1:3, 4:4, 5:5]
    # Groups are reported as indices into the original vector.
    groups = collect(StructArrays.groupindices(data))
    @test map(first, groups) == [1, 2, 3]
    @test map(last, groups) == [[1, 4, 5], [2], [3]]
end
79+
3980
@testset "similar" begin
4081
t = StructArray(a = rand(10), b = rand(Bool, 10))
4182
s = similar(t)

0 commit comments

Comments
 (0)