diff --git a/src/pattern.jl b/src/pattern.jl index d2df6659..1ea72434 100644 --- a/src/pattern.jl +++ b/src/pattern.jl @@ -16,6 +16,23 @@ function trace_input(::Type{T}, x::AbstractArray, i) where {T<:AbstractTracer} return tracer.(T, indices) end +## Trace function +function trace_function(::Type{T}, f, x) where {T<:AbstractTracer} + xt = trace_input(T, x) + yt = f(xt) + return xt, yt +end + +function trace_function(::Type{T}, f!, y, x) where {T<:AbstractTracer} + xt = trace_input(T, x) + yt = similar(y, T) + f!(yt, xt) + return xt, yt +end + +to_array(x::Number) = [x] +to_array(x::AbstractArray) = x + ## Construct sparsity pattern matrix """ connectivity_pattern(f, x) @@ -40,8 +57,10 @@ julia> connectivity_pattern(f, x) ⋅ ⋅ 1 ``` """ -connectivity_pattern(f, x, settype::Type{S}=DEFAULT_SET_TYPE) where {S} = - pattern(f, ConnectivityTracer{S}, x) +function connectivity_pattern(f, x, ::Type{S}=DEFAULT_SET_TYPE) where {S} + xt, yt = trace_function(ConnectivityTracer{S}, f, x) + return connectivity_pattern_to_mat(to_array(xt), to_array(yt)) +end """ connectivity_pattern(f!, y, x) @@ -53,7 +72,27 @@ where `C[i, j]` is true if the compute graph connects the `i`-th entry in `y` to The type of index set `S` can be specified as an optional argument and defaults to `BitSet`. """ function connectivity_pattern(f!, y, x, ::Type{S}=DEFAULT_SET_TYPE) where {S} - return pattern(f!, y, ConnectivityTracer{S}, x) + xt, yt = trace_function(ConnectivityTracer{S}, f!, y, x) + return connectivity_pattern_to_mat(to_array(xt), to_array(yt)) +end + +function connectivity_pattern_to_mat( + xt::AbstractArray{T}, yt::AbstractArray{<:Number} +) where {T<:ConnectivityTracer} + n, m = length(xt), length(yt) + I = UInt64[] # row indices + J = UInt64[] # column indices + V = Bool[] # values + for (i, y) in enumerate(yt) + if y isa T + for j in inputs(y) + push!(I, i) + push!(J, j) + push!(V, true) + end + end + end + return sparse(I, J, V, m, n) end """ @@ -79,7 +118,8 @@ julia> jacobian_pattern(f, x) ``` """ function jacobian_pattern(f, x, ::Type{S}=DEFAULT_SET_TYPE) where {S} - return pattern(f, JacobianTracer{S}, x) + xt, yt = trace_function(JacobianTracer{S}, f, x) + return jacobian_pattern_to_mat(to_array(xt), to_array(yt)) end """ @@ -91,7 +131,27 @@ Compute the sparsity pattern of the Jacobian of `f!(y, x)`. The type of index set `S` can be specified as an optional argument and defaults to `BitSet`. """ function jacobian_pattern(f!, y, x, ::Type{S}=DEFAULT_SET_TYPE) where {S} - return pattern(f!, y, JacobianTracer{S}, x) + xt, yt = trace_function(JacobianTracer{S}, f!, y, x) + return jacobian_pattern_to_mat(to_array(xt), to_array(yt)) +end + +function jacobian_pattern_to_mat( + xt::AbstractArray{T}, yt::AbstractArray{<:Number} +) where {T<:JacobianTracer} + n, m = length(xt), length(yt) + I = UInt64[] # row indices + J = UInt64[] # column indices + V = Bool[] # values + for (i, y) in enumerate(yt) + if y isa T + for j in inputs(y) + push!(I, i) + push!(J, j) + push!(V, true) + end + end + end + return sparse(I, J, V, m, n) end """ @@ -129,63 +189,21 @@ julia> hessian_pattern(g, x) ``` """ function hessian_pattern(f, x, ::Type{S}=DEFAULT_SET_TYPE) where {S} - return pattern(f, HessianTracer{S}, x) -end - -function pattern(f, ::Type{T}, x) where {T<:AbstractTracer} - xt = trace_input(T, x) - yt = f(xt) - return _pattern(xt, yt) -end - -function pattern(f!, y, ::Type{T}, x) where {T<:AbstractTracer} - xt = trace_input(T, x) - yt = similar(y, T) - f!(yt, xt) - return _pattern(xt, yt) -end - -_pattern(xt::AbstractTracer, yt::Number) = _pattern([xt], [yt]) -_pattern(xt::AbstractTracer, yt::AbstractArray{<:Number}) = _pattern([xt], yt) -_pattern(xt::AbstractArray{<:AbstractTracer}, yt::Number) = _pattern(xt, [yt]) -function _pattern(xt::AbstractArray{<:AbstractTracer}, yt::AbstractArray{<:Number}) - return _pattern_to_sparsemat(xt, yt) + xt, yt = trace_function(HessianTracer{S}, f, x) + return hessian_pattern_to_mat(to_array(xt), yt) end -function _pattern_to_sparsemat( - xt::AbstractArray{T}, yt::AbstractArray{<:Number} -) where {T<:AbstractTracer} - # Construct matrix of size (ouput_dim, input_dim) - n, m = length(xt), length(yt) - I = UInt64[] # row indices - J = UInt64[] # column indices - V = Bool[] # values - for (i, y) in enumerate(yt) - if y isa T - for j in inputs(y) - push!(I, i) - push!(J, j) - push!(V, true) - end - end - end - return sparse(I, J, V, m, n) -end - -function _pattern_to_sparsemat( - xt::AbstractArray{HessianTracer{S}}, yt::AbstractArray{HessianTracer{S}} +function hessian_pattern_to_mat( + xt::AbstractArray{HessianTracer{S}}, yt::HessianTracer{S} ) where {S} - length(yt) != 1 && error("pattern(f, HessianTracer, x) expects scalar output y=f(x).") - y = only(yt) - # Allocate Hessian matrix n = length(xt) I = UInt64[] # row indices J = UInt64[] # column indices V = Bool[] # values - for i in keys(y.inputs) - for j in y.inputs[i] + for i in keys(yt.inputs) + for j in yt.inputs[i] push!(I, i) push!(J, j) push!(V, true)