Skip to content

Commit

Permalink
Speed up coloring (#242)
Browse files Browse the repository at this point in the history
  • Loading branch information
gdalle authored May 2, 2024
1 parent 93c9a39 commit 16aa888
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 31 deletions.
2 changes: 1 addition & 1 deletion DifferentiationInterface/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "DifferentiationInterface"
uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
authors = ["Guillaume Dalle", "Adrian Hill"]
version = "0.3.2"
version = "0.3.3"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
64 changes: 34 additions & 30 deletions DifferentiationInterface/src/sparse/coloring.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,33 +45,24 @@ end
neighbors_of_column(g::BipartiteGraph, j::Integer) = nz_in_col(g.A_colmajor, j)
neighbors_of_row(g::BipartiteGraph, i::Integer) = nz_in_row(g.A_rowmajor, i)

function colored_neighbors_of_column(
g::BipartiteGraph, j::Integer, colors::AbstractVector{<:Integer}
)
return filter(neighbors_of_column(g, j)) do i
!iszero(colors[i])
end
end

function colored_neighbors_of_row(
g::BipartiteGraph, i::Integer, colors::AbstractVector{<:Integer}
)
return filter(neighbors_of_row(g, i)) do j
!iszero(colors[j])
end
end

function distance2_column_coloring(g::BipartiteGraph)
n = length(columns(g))
colors = zeros(Int, n)
forbidden_colors = zeros(Int, n)
for v in columns(g) # default ordering
for w in neighbors_of_column(g, v)
for x in colored_neighbors_of_row(g, w, colors)
forbidden_colors[colors[x]] = v
for x in neighbors_of_row(g, w)
if !iszero(colors[x])
forbidden_colors[colors[x]] = v
end
end
end
for c in columns(g)
if forbidden_colors[c] != v
colors[v] = c
break
end
end
colors[v] = minimum(c for c in columns(g) if forbidden_colors[c] != v)
end
return colors
end
Expand All @@ -82,11 +73,18 @@ function distance2_row_coloring(g::BipartiteGraph)
forbidden_colors = zeros(Int, m)
for v in 1:m # default ordering
for w in neighbors_of_row(g, v)
for x in colored_neighbors_of_column(g, w, colors)
forbidden_colors[colors[x]] = v
for x in neighbors_of_column(g, w)
if !iszero(colors[x])
forbidden_colors[colors[x]] = v
end
end
end
for c in rows(g)
if forbidden_colors[c] != v
colors[v] = c
break
end
end
colors[v] = minimum(c for c in rows(g) if forbidden_colors[c] != v)
end
return colors
end
Expand Down Expand Up @@ -166,21 +164,27 @@ function star_coloring(g::AdjacencyGraph)
if !iszero(colors[w]) # w is colored
forbidden_colors[colors[w]] = v
end
for x in colored_neighbors(g, w, colors)
if iszero(colors[w]) # w is not colored
for x in neighbors(g, w)
if !iszero(colors[x]) && iszero(colors[w]) # w is not colored
forbidden_colors[colors[x]] = v
else
for y in colored_neighbors(g, x, colors)
y != w || continue
if colors[y] == colors[w]
forbidden_colors[colors[x]] = v
break
for y in neighbors(g, x)
if !iszero(colors[y]) && y != w
if colors[y] == colors[w]
forbidden_colors[colors[x]] = v
break
end
end
end
end
end
end
colors[v] = minimum(c for c in columns(g) if forbidden_colors[c] != v)
for c in columns(g)
if forbidden_colors[c] != v
colors[v] = c
break
end
end
end
return colors
end
Expand Down

0 comments on commit 16aa888

Please sign in to comment.