Skip to content

Commit

Permalink
Refactor pattern extraction (#37)
Browse files Browse the repository at this point in the history
  • Loading branch information
adrhill authored May 3, 2024
1 parent 8a9823e commit 83cb0b0
Showing 1 changed file with 71 additions and 53 deletions.
124 changes: 71 additions & 53 deletions src/pattern.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,23 @@ function trace_input(::Type{T}, x::AbstractArray, i) where {T<:AbstractTracer}
return tracer.(T, indices)
end

## Trace function
function trace_function(::Type{T}, f, x) where {T<:AbstractTracer}
xt = trace_input(T, x)
yt = f(xt)
return xt, yt
end

function trace_function(::Type{T}, f!, y, x) where {T<:AbstractTracer}
xt = trace_input(T, x)
yt = similar(y, T)
f!(yt, xt)
return xt, yt
end

to_array(x::Number) = [x]
to_array(x::AbstractArray) = x

## Construct sparsity pattern matrix
"""
connectivity_pattern(f, x)
Expand All @@ -40,8 +57,10 @@ julia> connectivity_pattern(f, x)
⋅ ⋅ 1
```
"""
connectivity_pattern(f, x, settype::Type{S}=DEFAULT_SET_TYPE) where {S} =
pattern(f, ConnectivityTracer{S}, x)
function connectivity_pattern(f, x, ::Type{S}=DEFAULT_SET_TYPE) where {S}
xt, yt = trace_function(ConnectivityTracer{S}, f, x)
return connectivity_pattern_to_mat(to_array(xt), to_array(yt))
end

"""
connectivity_pattern(f!, y, x)
Expand All @@ -53,7 +72,27 @@ where `C[i, j]` is true if the compute graph connects the `i`-th entry in `y` to
The type of index set `S` can be specified as an optional argument and defaults to `BitSet`.
"""
function connectivity_pattern(f!, y, x, ::Type{S}=DEFAULT_SET_TYPE) where {S}
return pattern(f!, y, ConnectivityTracer{S}, x)
xt, yt = trace_function(ConnectivityTracer{S}, f!, y, x)
return connectivity_pattern_to_mat(to_array(xt), to_array(yt))
end

function connectivity_pattern_to_mat(
xt::AbstractArray{T}, yt::AbstractArray{<:Number}
) where {T<:ConnectivityTracer}
n, m = length(xt), length(yt)
I = UInt64[] # row indices
J = UInt64[] # column indices
V = Bool[] # values
for (i, y) in enumerate(yt)
if y isa T
for j in inputs(y)
push!(I, i)
push!(J, j)
push!(V, true)
end
end
end
return sparse(I, J, V, m, n)
end

"""
Expand All @@ -79,7 +118,8 @@ julia> jacobian_pattern(f, x)
```
"""
function jacobian_pattern(f, x, ::Type{S}=DEFAULT_SET_TYPE) where {S}
return pattern(f, JacobianTracer{S}, x)
xt, yt = trace_function(JacobianTracer{S}, f, x)
return jacobian_pattern_to_mat(to_array(xt), to_array(yt))
end

"""
Expand All @@ -91,7 +131,27 @@ Compute the sparsity pattern of the Jacobian of `f!(y, x)`.
The type of index set `S` can be specified as an optional argument and defaults to `BitSet`.
"""
function jacobian_pattern(f!, y, x, ::Type{S}=DEFAULT_SET_TYPE) where {S}
return pattern(f!, y, JacobianTracer{S}, x)
xt, yt = trace_function(JacobianTracer{S}, f!, y, x)
return jacobian_pattern_to_mat(to_array(xt), to_array(yt))
end

function jacobian_pattern_to_mat(
xt::AbstractArray{T}, yt::AbstractArray{<:Number}
) where {T<:JacobianTracer}
n, m = length(xt), length(yt)
I = UInt64[] # row indices
J = UInt64[] # column indices
V = Bool[] # values
for (i, y) in enumerate(yt)
if y isa T
for j in inputs(y)
push!(I, i)
push!(J, j)
push!(V, true)
end
end
end
return sparse(I, J, V, m, n)
end

"""
Expand Down Expand Up @@ -129,63 +189,21 @@ julia> hessian_pattern(g, x)
```
"""
function hessian_pattern(f, x, ::Type{S}=DEFAULT_SET_TYPE) where {S}
return pattern(f, HessianTracer{S}, x)
end

function pattern(f, ::Type{T}, x) where {T<:AbstractTracer}
xt = trace_input(T, x)
yt = f(xt)
return _pattern(xt, yt)
end

function pattern(f!, y, ::Type{T}, x) where {T<:AbstractTracer}
xt = trace_input(T, x)
yt = similar(y, T)
f!(yt, xt)
return _pattern(xt, yt)
end

_pattern(xt::AbstractTracer, yt::Number) = _pattern([xt], [yt])
_pattern(xt::AbstractTracer, yt::AbstractArray{<:Number}) = _pattern([xt], yt)
_pattern(xt::AbstractArray{<:AbstractTracer}, yt::Number) = _pattern(xt, [yt])
function _pattern(xt::AbstractArray{<:AbstractTracer}, yt::AbstractArray{<:Number})
return _pattern_to_sparsemat(xt, yt)
xt, yt = trace_function(HessianTracer{S}, f, x)
return hessian_pattern_to_mat(to_array(xt), yt)
end

function _pattern_to_sparsemat(
xt::AbstractArray{T}, yt::AbstractArray{<:Number}
) where {T<:AbstractTracer}
# Construct matrix of size (ouput_dim, input_dim)
n, m = length(xt), length(yt)
I = UInt64[] # row indices
J = UInt64[] # column indices
V = Bool[] # values
for (i, y) in enumerate(yt)
if y isa T
for j in inputs(y)
push!(I, i)
push!(J, j)
push!(V, true)
end
end
end
return sparse(I, J, V, m, n)
end

function _pattern_to_sparsemat(
xt::AbstractArray{HessianTracer{S}}, yt::AbstractArray{HessianTracer{S}}
function hessian_pattern_to_mat(
xt::AbstractArray{HessianTracer{S}}, yt::HessianTracer{S}
) where {S}
length(yt) != 1 && error("pattern(f, HessianTracer, x) expects scalar output y=f(x).")
y = only(yt)

# Allocate Hessian matrix
n = length(xt)
I = UInt64[] # row indices
J = UInt64[] # column indices
V = Bool[] # values

for i in keys(y.inputs)
for j in y.inputs[i]
for i in keys(yt.inputs)
for j in yt.inputs[i]
push!(I, i)
push!(J, j)
push!(V, true)
Expand Down

0 comments on commit 83cb0b0

Please sign in to comment.