Skip to content

Commit 16aa888

Browse files
authored
Speed up coloring (#242)
1 parent 93c9a39 commit 16aa888

File tree

2 files changed

+35
-31
lines changed

2 files changed

+35
-31
lines changed

DifferentiationInterface/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DifferentiationInterface"
22
uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
33
authors = ["Guillaume Dalle", "Adrian Hill"]
4-
version = "0.3.2"
4+
version = "0.3.3"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

DifferentiationInterface/src/sparse/coloring.jl

Lines changed: 34 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -45,33 +45,24 @@ end
4545
neighbors_of_column(g::BipartiteGraph, j::Integer) = nz_in_col(g.A_colmajor, j)
4646
neighbors_of_row(g::BipartiteGraph, i::Integer) = nz_in_row(g.A_rowmajor, i)
4747

48-
function colored_neighbors_of_column(
49-
g::BipartiteGraph, j::Integer, colors::AbstractVector{<:Integer}
50-
)
51-
return filter(neighbors_of_column(g, j)) do i
52-
!iszero(colors[i])
53-
end
54-
end
55-
56-
function colored_neighbors_of_row(
57-
g::BipartiteGraph, i::Integer, colors::AbstractVector{<:Integer}
58-
)
59-
return filter(neighbors_of_row(g, i)) do j
60-
!iszero(colors[j])
61-
end
62-
end
63-
6448
function distance2_column_coloring(g::BipartiteGraph)
6549
n = length(columns(g))
6650
colors = zeros(Int, n)
6751
forbidden_colors = zeros(Int, n)
6852
for v in columns(g) # default ordering
6953
for w in neighbors_of_column(g, v)
70-
for x in colored_neighbors_of_row(g, w, colors)
71-
forbidden_colors[colors[x]] = v
54+
for x in neighbors_of_row(g, w)
55+
if !iszero(colors[x])
56+
forbidden_colors[colors[x]] = v
57+
end
58+
end
59+
end
60+
for c in columns(g)
61+
if forbidden_colors[c] != v
62+
colors[v] = c
63+
break
7264
end
7365
end
74-
colors[v] = minimum(c for c in columns(g) if forbidden_colors[c] != v)
7566
end
7667
return colors
7768
end
@@ -82,11 +73,18 @@ function distance2_row_coloring(g::BipartiteGraph)
8273
forbidden_colors = zeros(Int, m)
8374
for v in 1:m # default ordering
8475
for w in neighbors_of_row(g, v)
85-
for x in colored_neighbors_of_column(g, w, colors)
86-
forbidden_colors[colors[x]] = v
76+
for x in neighbors_of_column(g, w)
77+
if !iszero(colors[x])
78+
forbidden_colors[colors[x]] = v
79+
end
80+
end
81+
end
82+
for c in rows(g)
83+
if forbidden_colors[c] != v
84+
colors[v] = c
85+
break
8786
end
8887
end
89-
colors[v] = minimum(c for c in rows(g) if forbidden_colors[c] != v)
9088
end
9189
return colors
9290
end
@@ -166,21 +164,27 @@ function star_coloring(g::AdjacencyGraph)
166164
if !iszero(colors[w]) # w is colored
167165
forbidden_colors[colors[w]] = v
168166
end
169-
for x in colored_neighbors(g, w, colors)
170-
if iszero(colors[w]) # w is not colored
167+
for x in neighbors(g, w)
168+
if !iszero(colors[x]) && iszero(colors[w]) # w is not colored
171169
forbidden_colors[colors[x]] = v
172170
else
173-
for y in colored_neighbors(g, x, colors)
174-
y != w || continue
175-
if colors[y] == colors[w]
176-
forbidden_colors[colors[x]] = v
177-
break
171+
for y in neighbors(g, x)
172+
if !iszero(colors[y]) && y != w
173+
if colors[y] == colors[w]
174+
forbidden_colors[colors[x]] = v
175+
break
176+
end
178177
end
179178
end
180179
end
181180
end
182181
end
183-
colors[v] = minimum(c for c in columns(g) if forbidden_colors[c] != v)
182+
for c in columns(g)
183+
if forbidden_colors[c] != v
184+
colors[v] = c
185+
break
186+
end
187+
end
184188
end
185189
return colors
186190
end

0 commit comments

Comments
 (0)