Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce sparsity patterns #139

Merged
merged 10 commits into from
Jun 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions benchmark/bench_jogger.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ Pkg.develop(; path=joinpath(@__DIR__, "SparseConnectivityTracerBenchmarks"))
using BenchmarkTools
using SparseConnectivityTracer
using SparseConnectivityTracer: GradientTracer, HessianTracer
using SparseConnectivityTracer: IndexSetGradientPattern, IndexSetHessianPattern
using SparseConnectivityTracer: DuplicateVector, SortedVector, RecursiveSet

SET_TYPES = (BitSet, Set{Int}, DuplicateVector{Int}, RecursiveSet{Int}, SortedVector{Int})
Expand All @@ -19,8 +20,11 @@ suite["OptimizationProblems"] = optbench([:britgas])
for S1 in SET_TYPES
S2 = Set{Tuple{Int,Int}}

G = GradientTracer{S1}
H = HessianTracer{S1,S2}
PG = IndexSetGradientPattern{Int,S1}
PH = IndexSetHessianPattern{Int,S1,S2}

G = GradientTracer{PG}
H = HessianTracer{PH}

suite["Jacobian"]["Global"][nameof(S1)] = jacbench(TracerSparsityDetector(G, H))
suite["Jacobian"]["Local"][nameof(S1)] = jacbench(TracerLocalSparsityDetector(G, H))
Expand Down
1 change: 1 addition & 0 deletions src/SparseConnectivityTracer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ include("settypes/duplicatevector.jl")
include("settypes/recursiveset.jl")
include("settypes/sortedvector.jl")

include("patterns.jl")
include("tracers.jl")
include("exceptions.jl")
include("operators.jl")
Expand Down
8 changes: 5 additions & 3 deletions src/interface.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
const DEFAULT_CONNECTIVITY_TRACER = ConnectivityTracer{BitSet}
const DEFAULT_GRADIENT_TRACER = GradientTracer{BitSet}
const DEFAULT_HESSIAN_TRACER = HessianTracer{BitSet,Set{Tuple{Int,Int}}}
const DEFAULT_CONNECTIVITY_TRACER = ConnectivityTracer{IndexSetGradientPattern{Int,BitSet}}
const DEFAULT_GRADIENT_TRACER = GradientTracer{IndexSetGradientPattern{Int,BitSet}}
const DEFAULT_HESSIAN_TRACER = HessianTracer{
IndexSetHessianPattern{Int,BitSet,Set{Tuple{Int,Int}}}
}

#==================#
# Enumerate inputs #
Expand Down
14 changes: 7 additions & 7 deletions src/overloads/connectivity_tracer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,23 +49,23 @@ end
return connectivity_tracer_1_to_1(ty, is_infl_arg2_zero)
else
i_out = connectivity_tracer_2_to_1_inner(
inputs(tx), inputs(ty), is_infl_arg1_zero, is_infl_arg2_zero
pattern(tx), pattern(ty), is_infl_arg1_zero, is_infl_arg2_zero
)
return T(i_out) # return tracer
end
end

function connectivity_tracer_2_to_1_inner(
sx::S, sy::S, is_infl_arg1_zero::Bool, is_infl_arg2_zero::Bool
) where {S<:AbstractSet{<:Integer}}
px::P, py::P, is_infl_arg1_zero::Bool, is_infl_arg2_zero::Bool
) where {P<:IndexSetGradientPattern}
if is_infl_arg1_zero && is_infl_arg2_zero
return myempty(S)
return myempty(P)
elseif !is_infl_arg1_zero && is_infl_arg2_zero
return sx
return px
elseif is_infl_arg1_zero && !is_infl_arg2_zero
return sy
return py
else
return union(sx, sy) # return set
return P(union(set(px), set(py))) # return pattern
end
end

Expand Down
21 changes: 19 additions & 2 deletions src/overloads/gradient_tracer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,14 @@
end
end

# Called by HessianTracer with AbstractSet
function gradient_tracer_1_to_1_inner(
p::P, is_der1_zero::Bool
) where {P<:IndexSetGradientPattern}
return P(gradient_tracer_1_to_1_inner(set(p), is_der1_zero)) # return pattern
end

# This is only required because it is called by HessianTracer with IndexSetHessianPattern
# Otherwise, we would just have the method on IndexSetGradientPattern above.
function gradient_tracer_1_to_1_inner(
s::S, is_der1_zero::Bool
) where {S<:AbstractSet{<:Integer}}
Expand Down Expand Up @@ -60,12 +67,22 @@ end
return gradient_tracer_1_to_1(ty, is_der1_arg2_zero)
else
g_out = gradient_tracer_2_to_1_inner(
gradient(tx), gradient(ty), is_der1_arg1_zero, is_der1_arg2_zero
pattern(tx), pattern(ty), is_der1_arg1_zero, is_der1_arg2_zero
)
return T(g_out) # return tracer
end
end

function gradient_tracer_2_to_1_inner(
px::P, py::P, is_der1_arg1_zero::Bool, is_der1_arg2_zero::Bool
) where {P<:IndexSetGradientPattern}
return P(
gradient_tracer_2_to_1_inner(set(px), set(py), is_der1_arg1_zero, is_der1_arg2_zero)
) # return pattern
end

# This is only required because it is called by HessianTracer with IndexSetHessianPattern
# Otherwise, we would just have the method on IndexSetGradientPattern above.
function gradient_tracer_2_to_1_inner(
sx::S, sy::S, is_der1_arg1_zero::Bool, is_der1_arg2_zero::Bool
) where {S<:AbstractSet{<:Integer}}
Expand Down
70 changes: 34 additions & 36 deletions src/overloads/hessian_tracer.jl
Original file line number Diff line number Diff line change
@@ -1,32 +1,32 @@
## 1-to-1

@noinline function hessian_tracer_1_to_1(
t::T, is_der1_zero::Bool, is_secondder_zero::Bool
t::T, is_der1_zero::Bool, is_der2_zero::Bool
) where {T<:HessianTracer}
if isemptytracer(t) # TODO: add test
return t
else
g_out, h_out = hessian_tracer_1_to_1_inner(
gradient(t), hessian(t), is_der1_zero, is_secondder_zero
)
return T(g_out, h_out) # return tracer
p_out = hessian_tracer_1_to_1_inner(pattern(t), is_der1_zero, is_der2_zero)
return T(p_out) # return tracer
end
end

function hessian_tracer_1_to_1_inner(
sg::G, sh::H, is_der1_zero::Bool, is_secondder_zero::Bool
) where {I<:Integer,G<:AbstractSet{I},H<:AbstractSet{Tuple{I,I}}}
p::P, is_der1_zero::Bool, is_der2_zero::Bool
) where {I,SG,SH,P<:IndexSetHessianPattern{I,SG,SH}}
sg = gradient(p)
sh = hessian(p)
sg_out = gradient_tracer_1_to_1_inner(sg, is_der1_zero)
sh_out = if is_der1_zero && is_secondder_zero
myempty(H)
elseif !is_der1_zero && is_secondder_zero
sh_out = if is_der1_zero && is_der2_zero
myempty(SH)
elseif !is_der1_zero && is_der2_zero
sh
elseif is_der1_zero && !is_secondder_zero
union_product!(myempty(H), sg, sg)
elseif is_der1_zero && !is_der2_zero
union_product!(myempty(SH), sg, sg)
else
union_product!(copy(sh), sg, sg)
end
return sg_out, sh_out # return sets
return P(sg_out, sh_out) # return pattern
end

function overload_hessian_1_to_1(M, op)
Expand Down Expand Up @@ -62,54 +62,52 @@ end
tx::T,
ty::T,
is_der1_arg1_zero::Bool,
is_secondder_arg1_zero::Bool,
is_der2_arg1_zero::Bool,
is_der1_arg2_zero::Bool,
is_secondder_arg2_zero::Bool,
is_der2_arg2_zero::Bool,
is_der_cross_zero::Bool,
) where {T<:HessianTracer}
# TODO: add tests for isempty
if tx.isempty && ty.isempty
return tx # empty tracer
elseif ty.isempty
return hessian_tracer_1_to_1(tx, is_der1_arg1_zero, is_secondder_arg1_zero)
return hessian_tracer_1_to_1(tx, is_der1_arg1_zero, is_der2_arg1_zero)
elseif tx.isempty
return hessian_tracer_1_to_1(ty, is_der1_arg2_zero, is_secondder_arg2_zero)
return hessian_tracer_1_to_1(ty, is_der1_arg2_zero, is_der2_arg2_zero)
else
g_out, h_out = hessian_tracer_2_to_1_inner(
gradient(tx),
hessian(tx),
gradient(ty),
hessian(ty),
p_out = hessian_tracer_2_to_1_inner(
pattern(tx),
pattern(ty),
is_der1_arg1_zero,
is_secondder_arg1_zero,
is_der2_arg1_zero,
is_der1_arg2_zero,
is_secondder_arg2_zero,
is_der2_arg2_zero,
is_der_cross_zero,
)
return T(g_out, h_out) # return tracer
return T(p_out) # return tracer
end
end

function hessian_tracer_2_to_1_inner(
sgx::G,
shx::H,
sgy::G,
shy::H,
px::P,
py::P,
is_der1_arg1_zero::Bool,
is_secondder_arg1_zero::Bool,
is_der2_arg1_zero::Bool,
is_der1_arg2_zero::Bool,
is_secondder_arg2_zero::Bool,
is_der2_arg2_zero::Bool,
is_der_cross_zero::Bool,
) where {I<:Integer,G<:AbstractSet{I},H<:AbstractSet{Tuple{I,I}}}
) where {I,SG,SH,P<:IndexSetHessianPattern{I,SG,SH}}
sgx, shx = gradient(px), hessian(px)
sgy, shy = gradient(py), hessian(py)
sg_out = gradient_tracer_2_to_1_inner(sgx, sgy, is_der1_arg1_zero, is_der1_arg2_zero)
sh_out = myempty(H)
sh_out = myempty(SH)
!is_der1_arg1_zero && union!(sh_out, shx) # hessian alpha
!is_der1_arg2_zero && union!(sh_out, shy) # hessian beta
!is_secondder_arg1_zero && union_product!(sh_out, sgx, sgx) # product alpha
!is_secondder_arg2_zero && union_product!(sh_out, sgy, sgy) # product beta
!is_der2_arg1_zero && union_product!(sh_out, sgx, sgx) # product alpha
!is_der2_arg2_zero && union_product!(sh_out, sgy, sgy) # product beta
!is_der_cross_zero && union_product!(sh_out, sgx, sgy) # cross product 1
!is_der_cross_zero && union_product!(sh_out, sgy, sgx) # cross product 2
return sg_out, sh_out # return sets
return P(sg_out, sh_out) # return pattern
end

function overload_hessian_2_to_1(M, op)
Expand Down
16 changes: 8 additions & 8 deletions src/overloads/ifelse_global.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,16 @@
end

## output union on scalar outputs
function output_union(tx::T, ty::T) where {T<:ConnectivityTracer}
return T(union(inputs(tx), inputs(ty)))
function output_union(tx::T, ty::T) where {T<:AbstractTracer}
return T(output_union(pattern(tx), pattern(ty))) # return tracer
end
function output_union(tx::T, ty::T) where {T<:GradientTracer}
return T(union(gradient(tx), gradient(ty)))
function output_union(px::P, py::P) where {P<:IndexSetGradientPattern}
return P(union(set(px), set(py))) # return pattern
end
function output_union(tx::T, ty::T) where {T<:HessianTracer}
g_out = union(gradient(tx), gradient(ty))
h_out = union(hessian(tx), hessian(ty))
return T(g_out, h_out)
function output_union(px::P, py::P) where {P<:IndexSetHessianPattern}
g_out = union(gradient(px), gradient(py))
h_out = union(hessian(px), hessian(py))
return P(g_out, h_out) # return pattern
end

output_union(tx::AbstractTracer, y) = tx
Expand Down
Loading
Loading