Skip to content

Commit

Permalink
Fixes for patterns introduced in #139
Browse files Browse the repository at this point in the history
  • Loading branch information
adrhill committed Jun 27, 2024
1 parent 3f1e186 commit 3748c47
Show file tree
Hide file tree
Showing 11 changed files with 150 additions and 142 deletions.
17 changes: 8 additions & 9 deletions src/interface.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
const DEFAULT_GRADIENT_TRACER = GradientTracer{IndexSetGradientPattern{Int,BitSet}}
const DEFAULT_HESSIAN_TRACER = HessianTracer{
IndexSetHessianPattern{Int,BitSet,Set{Tuple{Int,Int}}}
IndexSetHessianPattern{Int,BitSet,Set{Tuple{Int,Int}},false}
}

#==================#
Expand All @@ -9,20 +9,19 @@ const DEFAULT_HESSIAN_TRACER = HessianTracer{

"""
trace_input(T, x)
trace_input(T, x)
trace_input(T, xs)
Enumerates input indices and constructs the specified type `T` of tracer.
Supports [`GradientTracer`](@ref), [`HessianTracer`](@ref) and [`Dual`](@ref).
"""
trace_input(::Type{T}, x) where {T<:Union{AbstractTracer,Dual}} = trace_input(T, x, 1)
trace_input(::Type{T}, xs) where {T<:Union{AbstractTracer,Dual}} = trace_input(T, xs, 1)

function trace_input(::Type{T}, x::Real, i::Integer) where {T<:Union{AbstractTracer,Dual}}
return create_tracer(T, x, i)
end
function trace_input(::Type{T}, xs::AbstractArray, i) where {T<:Union{AbstractTracer,Dual}}
indices = reshape(1:length(xs), size(xs)) .+ (i - 1)
return create_tracers(T, xs, indices)
is = reshape(1:length(xs), size(xs)) .+ (i - 1)
return create_tracers(T, xs, is)
end
function trace_input(::Type{T}, x::Real, i::Integer) where {T<:Union{AbstractTracer,Dual}}
return only(create_tracers(T, [x], [i]))
end

#=========================#
Expand Down
72 changes: 47 additions & 25 deletions src/overloads/hessian_tracer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,27 +12,33 @@
end

function hessian_tracer_1_to_1_inner(
p::P, is_der1_zero::Bool, is_der2_zero::Bool; shared::Bool
) where {I,SG,SH,P<:IndexSetHessianPattern{I,SG,SH}}
p::P, is_der1_zero::Bool, is_der2_zero::Bool
) where {I,SG,SH,P<:IndexSetHessianPattern{I,SG,SH,false}}
sg = gradient(p)
sh = hessian(p)
sg_out = gradient_tracer_1_to_1_inner(sg, is_der1_zero)
if shared
sh_out = if is_secondder_zero
sh
else
union_product!(sh, sg, sg)
end
sh_out = if is_der1_zero && is_der2_zero
myempty(SH)
elseif !is_der1_zero && is_der2_zero
sh
elseif is_der1_zero && is_der2_zero
union_product!(myempty(SH), sg, sg)
else
sh_out = if is_der1_zero && is_secondder_zero
myempty(H)
elseif !is_der1_zero && is_secondder_zero
sh
elseif is_der1_zero && !is_secondder_zero
union_product!(myempty(H), sg, sg)
else
union_product!(copy(sh), sg, sg)
end
union_product!(copy(sh), sg, sg)
end
return P(sg_out, sh_out) # return pattern
end

function hessian_tracer_1_to_1_inner(
p::P, is_der1_zero::Bool, is_der2_zero::Bool
) where {I,SG,SH,P<:IndexSetHessianPattern{I,SG,SH,true}}
sg = gradient(p)
sh = hessian(p)
sg_out = gradient_tracer_1_to_1_inner(sg, is_der1_zero)
sh_out = if is_der2_zero
sh
else
union_product!(sh, sg, sg)
end
return P(sg_out, sh_out) # return pattern
end
Expand Down Expand Up @@ -104,17 +110,33 @@ function hessian_tracer_2_to_1_inner(
is_der1_arg2_zero::Bool,
is_der2_arg2_zero::Bool,
is_der_cross_zero::Bool,
) where {I,SG,SH,P<:IndexSetHessianPattern{I,SG,SH}}
) where {I,SG,SH,P<:IndexSetHessianPattern{I,SG,SH,false}}
sgx, shx = gradient(px), hessian(px)
sgy, shy = gradient(py), hessian(py)
sg_out = gradient_tracer_2_to_1_inner(sgx, sgy, is_der1_arg1_zero, is_der1_arg2_zero)
if shared
sh_out = union!(shx, shy)
else
sh_out = myempty(SH)
!is_der1_arg1_zero && union!(sh_out, shx) # hessian alpha
!is_der1_arg2_zero && union!(sh_out, shy) # hessian beta
end
sh_out = myempty(SH)
!is_der1_arg1_zero && union!(sh_out, shx) # hessian alpha
!is_der1_arg2_zero && union!(sh_out, shy) # hessian beta
!is_der2_arg1_zero && union_product!(sh_out, sgx, sgx) # product alpha
!is_der2_arg2_zero && union_product!(sh_out, sgy, sgy) # product beta
!is_der_cross_zero && union_product!(sh_out, sgx, sgy) # cross product 1
!is_der_cross_zero && union_product!(sh_out, sgy, sgx) # cross product 2
return P(sg_out, sh_out) # return pattern
end

function hessian_tracer_2_to_1_inner(
px::P,
py::P,
is_der1_arg1_zero::Bool,
is_der2_arg1_zero::Bool,
is_der1_arg2_zero::Bool,
is_der2_arg2_zero::Bool,
is_der_cross_zero::Bool,
) where {I,SG,SH,P<:IndexSetHessianPattern{I,SG,SH,true}}
sgx, shx = gradient(px), hessian(px)
sgy, shy = gradient(py), hessian(py)
sg_out = gradient_tracer_2_to_1_inner(sgx, sgy, is_der1_arg1_zero, is_der1_arg2_zero)
sh_out = union!(shx, shy)
!is_der2_arg1_zero && union_product!(sh_out, sgx, sgx) # product alpha
!is_der2_arg2_zero && union_product!(sh_out, sgy, sgy) # product beta
!is_der_cross_zero && union_product!(sh_out, sgx, sgy) # cross product 1
Expand Down
17 changes: 11 additions & 6 deletions src/overloads/ifelse_global.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,18 @@
function output_union(px::P, py::P) where {P<:IndexSetGradientPattern}
return P(union(set(px), set(py))) # return pattern
end
function output_union(px::P, py::P) where {G,H,shared,P<:IndexSetHessianPattern{G,H,shared}}
function output_union(
px::P, py::P
) where {I,SG,SH,P<:IndexSetHessianPattern{I,SG,SH,false}} # non-mutating
g_out = union(gradient(px), gradient(py))
h_out = union(hessian(px), hessian(py))
return P(g_out, h_out) # return pattern
end
function output_union(
px::P, py::P
) where {I,SG,SH,P<:IndexSetHessianPattern{I,SG,SH,true}} # mutating
g_out = union(gradient(px), gradient(py))
h_out = if shared
union!(hessian(tx), hessian(ty))
else
union(hessian(px), hessian(py))
end
h_out = union!(hessian(px), hessian(py))
return P(g_out, h_out) # return pattern
end

Expand Down
68 changes: 44 additions & 24 deletions src/patterns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,20 @@ AbstractPattern
├── AbstractGradientPattern: used in GradientTracer
│ └── IndexSetGradientPattern
└── AbstractHessianPattern: used in HessianTracer
└── IndexSetHessianPattern
├── IndexSetHessianPattern
└── SharedIndexSetHessianPattern
```
"""
abstract type AbstractPattern end

"""
isshared(pattern)
Indicates whether patterns share memory (mutate).
"""
isshared(::P) where {P<:AbstractPattern} = isshared(P)
isshared(::Type{P}) where {P<:AbstractPattern} = false

"""
myempty(T)
myempty(tracer)
Expand All @@ -25,13 +34,11 @@ Constructor for an empty tracer or pattern of type `T` representing a new number
myempty

"""
seed(T, i)
seed(tracer, i)
seed(pattern, i)
create_patterns(P, xs, is)
Constructor for a tracer or pattern of type `T` that only contains the given index `i`.
Convenience constructor for patterns of type `P` for multiple inputs `xs` and their indices `is`.
"""
seed
create_patterns

#==========================#
# Utilities on AbstractSet #
Expand Down Expand Up @@ -69,18 +76,17 @@ For use with [`GradientTracer`](@ref).
## Expected interface
* `myempty(::Type{MyPattern})`: return a pattern representing a new number (usually an empty pattern)
* `seed(::Type{MyPattern}, i::Integer)`: return an pattern that only contains the given index `i`
* `gradient(p::MyPattern)`: return non-zero indices `i` for use with `GradientTracer`
Note that besides their names, the last two functions are usually identical.
* [`myempty`](@ref)
* [`create_patterns`](@ref)
* `gradient(p::MyPattern)`: return non-zero indices `i` in the gradient representation
* [`isshared`](@ref) in case the pattern is shared (mutates). Defaults to false.
"""
abstract type AbstractGradientPattern <: AbstractPattern end

"""
$(TYPEDEF)
Vector sparsity pattern represented by an `AbstractSet` of indices ``{i}`` of non-zero values.
Gradient sparsity pattern represented by an `AbstractSet` of indices ``{i}`` of non-zero values.
## Fields
$(TYPEDFIELDS)
Expand All @@ -97,8 +103,9 @@ Base.show(io::IO, p::IndexSetGradientPattern) = Base.show(io, set(p))
function myempty(::Type{IndexSetGradientPattern{I,S}}) where {I,S}
return IndexSetGradientPattern{I,S}(myempty(S))
end
function seed(::Type{IndexSetGradientPattern{I,S}}, i) where {I,S}
return IndexSetGradientPattern{I,S}(seed(S, i))
function create_patterns(::Type{P}, xs, is) where {I,S,P<:IndexSetGradientPattern{I,S}}
sets = seed.(S, is)
return P.(sets)
end

# Tracer compatibility
Expand All @@ -118,29 +125,42 @@ For use with [`HessianTracer`](@ref).
## Expected interface
* `myempty(::Type{MyPattern})`: return a pattern representing a new number (usually an empty pattern)
* `seed(::Type{MyPattern}, i::Integer)`: return an pattern that only contains the given index `i` in the first-order representation
* [`myempty`](@ref)
* [`create_patterns`](@ref)
* `gradient(p::MyPattern)`: return non-zero indices `i` in the first-order representation
* `hessian(p::MyPattern)`: return non-zero indices `(i, j)` in the second-order representation
* [`isshared`](@ref) in case the pattern is shared (mutates). Defaults to false.
"""
abstract type AbstractHessianPattern <: AbstractPattern end

"""
IndexSetHessianPattern(vector::AbstractGradientPattern, mat::AbstractMatrixPattern)
$(TYPEDEF)
Hessian sparsity pattern represented by:
* an `AbstractSet` of indices ``i`` of non-zero values representing first-order sparsity
* an `AbstractSet` of index tuples ``(i,j)`` of non-zero values representing second-order sparsity
Gradient and Hessian sparsity patterns constructed by combining two AbstractSets.
## Fields
$(TYPEDFIELDS)
"""
struct IndexSetHessianPattern{I<:Integer,SG<:AbstractSet{I},SH<:AbstractSet{Tuple{I,I}}} <:
AbstractHessianPattern
struct IndexSetHessianPattern{
I<:Integer,SG<:AbstractSet{I},SH<:AbstractSet{Tuple{I,I}},mutating
} <: AbstractHessianPattern
gradient::SG
hessian::SH
end
isshared(::Type{IndexSetHessianPattern{I,SG,SH,true}}) where {I,SG,SH} = true

function myempty(::Type{IndexSetHessianPattern{I,SG,SH}}) where {I,SG,SH}
return IndexSetHessianPattern{I,SG,SH}(myempty(SG), myempty(SH))
function myempty(::Type{P}) where {I,SG,SH,M,P<:IndexSetHessianPattern{I,SG,SH,M}}
return P(myempty(SG), myempty(SH))
end
function seed(::Type{IndexSetHessianPattern{I,SG,SH}}, index) where {I,SG,SH}
return IndexSetHessianPattern{I,SG,SH}(seed(SG, index), myempty(SH))
function create_patterns(
::Type{P}, xs, is
) where {I,SG,SH,M,P<:IndexSetHessianPattern{I,SG,SH,M}}
gradients = seed.(SG, is)
hessian = myempty(SH)
return P.(gradients, Ref(hessian))
end

# Tracer compatibility
Expand Down
70 changes: 13 additions & 57 deletions src/tracers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -136,76 +136,32 @@ end
# Utilities #
#===========#

myempty(::T) where {T<:AbstractTracer} = myempty(T)
# isshared(::Type{T}) where {P,T<:GradientTracer{P}} = isshared(P) # no shared AbstractGradientPattern yet
isshared(::Type{T}) where {P,T<:HessianTracer{P}} = isshared(P)

# myempty(::Type{T}) where {P,T<:AbstractTracer{P}} = T(myempty(P), true) # JET complains about this
myempty(::T) where {T<:AbstractTracer} = myempty(T)
# myempty(::Type{T}) where {P,T<:AbstractTracer{P}} = T(myempty(P), true) # JET complains about this
myempty(::Type{T}) where {P,T<:GradientTracer{P}} = T(myempty(P), true)
myempty(::Type{T}) where {P,T<:HessianTracer{P}} = T(myempty(P), true)

seed(::T, i) where {T<:AbstractTracer} = seed(T, i)

# seed(::Type{T}, i) where {P,T<:AbstractTracer{P}} = T(seed(P, i)) # JET complains about this
seed(::Type{T}, i) where {P,T<:GradientTracer{P}} = T(seed(P, i))
seed(::Type{T}, i) where {P,T<:HessianTracer{P}} = T(seed(P, i))

"""
create_tracer(T, index)
Convenience constructor for [`GradientTracer`](@ref) and [`HessianTracer`](@ref) from input indices.
"""
function create_tracer(::Type{T}, ::Real, index::Integer) where {P,T<:AbstractTracer{P}}
return T(seed(P, index))
end

function create_tracer(::Type{Dual{P,T}}, primal::Real, index::Integer) where {P,T}
return Dual(primal, create_tracer(T, primal, index))
end

function create_tracer(::Type{ConnectivityTracer{I}}, ::Real, index::Integer) where {I}
return ConnectivityTracer{I}(seed(I, index))
end
function create_tracer(::Type{GradientTracer{G}}, ::Real, index::Integer) where {G}
return GradientTracer{G}(seed(G, index))
end
function create_tracer(
::Type{HessianTracer{G,H,shared}}, ::Real, index::Integer
) where {G,H,shared}
return HessianTracer{G,H,shared}(seed(G, index), myempty(H))
end

"""
create_tracers(T, xs, indices)
Convenience constructor for [`ConnectivityTracer`](@ref), [`GradientTracer`](@ref), [`HessianTracer`](@ref) and [`Dual`](@ref) from multiple inputs and their indices.
Allows the creation of shared tracer fields (sofar only for the Hessian).
Convenience constructor for [`GradientTracer`](@ref), [`HessianTracer`](@ref) and [`Dual`](@ref)
from multiple inputs `xs` and their indices `is`.
"""
function create_tracers(
::Type{T}, xs::AbstractArray{<:Real,N}, indices::AbstractArray{<:Integer,N}
) where {T<:Union{AbstractTracer,Dual},N}
return create_tracer.(T, xs, indices)
) where {P<:AbstractPattern,T<:AbstractTracer{P},N}
patterns = create_patterns(P, xs, indices)
return T.(patterns)
end

function create_tracers(
::Type{HessianTracer{G,H,true}},
xs::AbstractArray{<:Real,N},
indices::AbstractArray{<:Integer,N},
) where {G,H,N}
sh = myempty(H) # shared
return map(indices) do index
HessianTracer{G,H,true}(seed(G, index), sh)
end
end

function create_tracers(
::Type{Dual{P,HessianTracer{G,H,true}}},
xs::AbstractArray{<:Real,N},
indices::AbstractArray{<:Integer,N},
) where {P<:Real,G,H,N}
sh = myempty(H) # shared
return map(xs, indices) do x, index
Dual(x, HessianTracer{G,H,true}(seed(G, index), sh))
end
::Type{D}, xs::AbstractArray{<:Real,N}, indices::AbstractArray{<:Integer,N}
) where {P,T,D<:Dual{P,T},N}
tracers = create_tracers(T, xs, indices)
return D.(xs, tracers)
end

# Pretty-printing of Dual tracers
Expand Down
2 changes: 1 addition & 1 deletion test/brusselator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ using SparseConnectivityTracer: DuplicateVector, RecursiveSet, SortedVector
using SparseConnectivityTracerBenchmarks.ODE: Brusselator!
using Test

# Load definitions of GRADIENT_TRACERS and HESSIAN_TRACERS
# Load definitions of GRADIENT_TRACERS, GRADIENT_PATTERNS, HESSIAN_TRACERS and HESSIAN_PATTERNS
include("tracers_definitions.jl")

function test_brusselator(method::AbstractSparsityDetector)
Expand Down
2 changes: 1 addition & 1 deletion test/flux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ using SparseConnectivityTracer
using SparseConnectivityTracer: DuplicateVector, RecursiveSet, SortedVector
using Test

# Load definitions of GRADIENT_TRACERS and HESSIAN_TRACERS
# Load definitions of GRADIENT_TRACERS, GRADIENT_PATTERNS, HESSIAN_TRACERS and HESSIAN_PATTERNS
include("tracers_definitions.jl")

const INPUT_FLUX = reshape(
Expand Down
2 changes: 1 addition & 1 deletion test/test_constructors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ using SparseConnectivityTracer: primal, tracer, isemptytracer
using SparseConnectivityTracer: myempty, name
using Test

# Load definitions of GRADIENT_TRACERS and HESSIAN_TRACERS
# Load definitions of GRADIENT_TRACERS, GRADIENT_PATTERNS, HESSIAN_TRACERS and HESSIAN_PATTERNS
include("tracers_definitions.jl")

function test_nested_duals(::Type{T}) where {T<:AbstractTracer}
Expand Down
Loading

0 comments on commit 3748c47

Please sign in to comment.