From a78c8200e14c8c91def2eae24405b42c78a36e76 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 18 Jul 2024 10:29:00 +0200 Subject: [PATCH] Smarter symmetric decompression (#370) --- DifferentiationInterface/Project.toml | 2 +- .../src/DifferentiationInterface.jl | 6 ++++-- .../src/first_order/jacobian.jl | 4 ---- .../src/second_order/hessian.jl | 2 -- DifferentiationInterface/src/sparse/hessian.jl | 18 +++++++++++------- .../src/sparse/jacobian.jl | 4 ---- 6 files changed, 16 insertions(+), 20 deletions(-) diff --git a/DifferentiationInterface/Project.toml b/DifferentiationInterface/Project.toml index 61f98721c..9d5d80f5e 100644 --- a/DifferentiationInterface/Project.toml +++ b/DifferentiationInterface/Project.toml @@ -61,7 +61,7 @@ PolyesterForwardDiff = "0.1.1" ReverseDiff = "1.15.1" SparseArrays = "<0.0.1,1" SparseConnectivityTracer = "0.5.0" -SparseMatrixColorings = "0.3.2" +SparseMatrixColorings = "0.3.5" Symbolics = "5.27.1" Tapir = "0.2.4" Tracker = "0.2.33" diff --git a/DifferentiationInterface/src/DifferentiationInterface.jl b/DifferentiationInterface/src/DifferentiationInterface.jl index ad4a1a4ce..8eb746d07 100644 --- a/DifferentiationInterface/src/DifferentiationInterface.jl +++ b/DifferentiationInterface/src/DifferentiationInterface.jl @@ -12,7 +12,7 @@ module DifferentiationInterface using ADTypes: ADTypes, AbstractADType using ADTypes: mode, ForwardMode, ForwardOrReverseMode, ReverseMode, SymbolicMode using ADTypes: AutoSparse, dense_ad -using ADTypes: coloring_algorithm, column_coloring, row_coloring, symmetric_coloring +using ADTypes: coloring_algorithm, column_coloring, row_coloring using ADTypes: sparsity_detector, jacobian_sparsity, hessian_sparsity using ADTypes: AutoChainRules, @@ -42,7 +42,9 @@ using SparseMatrixColorings: decompress_rows, decompress_rows!, decompress_symmetric, - decompress_symmetric! + decompress_symmetric!, + symmetric_coloring_detailed, + StarSet abstract type Extras end diff --git a/DifferentiationInterface/src/first_order/jacobian.jl b/DifferentiationInterface/src/first_order/jacobian.jl index 802f7354f..f389da679 100644 --- a/DifferentiationInterface/src/first_order/jacobian.jl +++ b/DifferentiationInterface/src/first_order/jacobian.jl @@ -288,9 +288,7 @@ function jacobian_aux!( batched_seeds[a], pushforward_batched_extras_same, ) - end - for a in eachindex(batched_results) for b in eachindex(batched_results[a].elements) copyto!( view(jac, :, 1 + ((a - 1) * B + (b - 1)) % N), @@ -320,9 +318,7 @@ function jacobian_aux!( batched_seeds[a], pullback_batched_extras_same, ) - end - for a in eachindex(batched_results) for b in eachindex(batched_results[a].elements) copyto!( view(jac, 1 + ((a - 1) * B + (b - 1)) % M, :), diff --git a/DifferentiationInterface/src/second_order/hessian.jl b/DifferentiationInterface/src/second_order/hessian.jl index 6eb785f95..3fde302e0 100644 --- a/DifferentiationInterface/src/second_order/hessian.jl +++ b/DifferentiationInterface/src/second_order/hessian.jl @@ -143,9 +143,7 @@ function hessian!( hvp_batched!( f, batched_results[a], backend, x, batched_seeds[a], hvp_batched_extras_same ) - end - for a in eachindex(batched_results) for b in eachindex(batched_results[a].elements) copyto!( view(hess, :, 1 + ((a - 1) * B + (b - 1)) % N), diff --git a/DifferentiationInterface/src/sparse/hessian.jl b/DifferentiationInterface/src/sparse/hessian.jl index 869aecb2b..e7b1c6690 100644 --- a/DifferentiationInterface/src/sparse/hessian.jl +++ b/DifferentiationInterface/src/sparse/hessian.jl @@ -3,6 +3,7 @@ struct SparseHessianExtras{ } <: HessianExtras sparsity::S colors::Vector{Int} + star_set::StarSet groups::Vector{Vector{Int}} compressed::C batched_seeds::Vector{Batch{B,D}} @@ -14,6 +15,7 @@ end function SparseHessianExtras{B}(; sparsity::S, colors, + star_set, groups, compressed::C, batched_seeds::Vector{Batch{B,D}}, @@ -25,6 +27,7 @@ function SparseHessianExtras{B}(; return SparseHessianExtras{B,S,C,D,R,E2,E1}( sparsity, colors, + star_set, groups, compressed, batched_seeds, @@ -40,7 +43,7 @@ function prepare_hessian(f::F, backend::AutoSparse, x) where {F} dense_backend = dense_ad(backend) initial_sparsity = hessian_sparsity(f, x, sparsity_detector(backend)) sparsity = col_major(initial_sparsity) - colors = symmetric_coloring(sparsity, coloring_algorithm(backend)) + colors, star_set = symmetric_coloring_detailed(sparsity, coloring_algorithm(backend)) groups = color_groups(colors) Ng = length(groups) B = pick_batchsize(maybe_outer(dense_backend), Ng) @@ -57,6 +60,7 @@ function prepare_hessian(f::F, backend::AutoSparse, x) where {F} return SparseHessianExtras{B}(; sparsity, colors, + star_set, groups, compressed, batched_seeds, @@ -67,8 +71,9 @@ function prepare_hessian(f::F, backend::AutoSparse, x) where {F} end function hessian(f::F, backend::AutoSparse, x, extras::SparseHessianExtras{B}) where {F,B} - @compat (; sparsity, compressed, colors, groups, batched_seeds, hvp_batched_extras) = - extras + @compat (; + sparsity, compressed, colors, star_set, groups, batched_seeds, hvp_batched_extras + ) = extras dense_backend = dense_ad(backend) Ng = length(groups) @@ -85,7 +90,7 @@ function hessian(f::F, backend::AutoSparse, x, extras::SparseHessianExtras{B}) w if Ng < size(compressed, 2) compressed = compressed[:, 1:Ng] end - return decompress_symmetric(sparsity, compressed, colors) + return decompress_symmetric(sparsity, compressed, colors, star_set) end function hessian!( @@ -95,6 +100,7 @@ function hessian!( sparsity, compressed, colors, + star_set, groups, batched_seeds, batched_results, @@ -116,9 +122,7 @@ function hessian!( batched_seeds[a], hvp_batched_extras_same, ) - end - for a in eachindex(batched_results) for b in eachindex(batched_results[a].elements) copyto!( view(compressed, :, 1 + ((a - 1) * B + (b - 1)) % Ng), @@ -127,7 +131,7 @@ function hessian!( end end - decompress_symmetric!(hess, sparsity, compressed, colors) + decompress_symmetric!(hess, sparsity, compressed, colors, star_set) return hess end diff --git a/DifferentiationInterface/src/sparse/jacobian.jl b/DifferentiationInterface/src/sparse/jacobian.jl index d9a491e35..4dee60cac 100644 --- a/DifferentiationInterface/src/sparse/jacobian.jl +++ b/DifferentiationInterface/src/sparse/jacobian.jl @@ -291,9 +291,7 @@ function sparse_jacobian_aux!( batched_seeds[a], pushforward_batched_extras_same, ) - end - for a in eachindex(batched_results) for b in eachindex(batched_results[a].elements) copyto!( view(compressed, :, 1 + ((a - 1) * B + (b - 1)) % Ng), @@ -334,9 +332,7 @@ function sparse_jacobian_aux!( batched_seeds[a], pullback_batched_extras_same, ) - end - for a in eachindex(batched_results) for b in eachindex(batched_results[a].elements) copyto!( view(compressed, 1 + ((a - 1) * B + (b - 1)) % Ng, :),