Skip to content

Commit

Permalink
Rename Tracer to ConnectivityTracer
Browse files Browse the repository at this point in the history
  • Loading branch information
adrhill committed Apr 23, 2024
1 parent 770fb51 commit 91bddf9
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 55 deletions.
6 changes: 3 additions & 3 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@ TracerSparsityDetector
```

## Internals
SparseConnectivityTracer works by pushing a `Number` type called [`Tracer`](@ref) through generic functions:
SparseConnectivityTracer works by pushing a `Number` type called [`ConnectivityTracer`](@ref) through generic functions:
```@docs
Tracer
ConnectivityTracer
tracer
trace_input
```

The following utilities can be used to extract input indices from [`Tracer`](@ref)s:
The following utilities can be used to extract input indices from [`ConnectivityTracer`](@ref)s:
```@docs
inputs
```
4 changes: 2 additions & 2 deletions src/SparseConnectivityTracer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ include("overload_connectivity.jl")
include("connectivity.jl")
include("adtypes.jl")

export Tracer
export tracer, trace_input
export ConnectivityTracer, connectivitytracer
export trace_input
export inputs
export connectivity
export TracerSparsityDetector
Expand Down
36 changes: 18 additions & 18 deletions src/connectivity.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"""
trace_input(x)
Enumerates input indices and constructs [`Tracer`](@ref)s.
Enumerates input indices and constructs [`ConnectivityTracer`](@ref)s.
## Example
```jldoctest
Expand All @@ -12,23 +12,23 @@ julia> x = rand(3);
julia> f(x) = [x[1]^2, 2 * x[1] * x[2]^2, sin(x[3])];
julia> xt = trace_input(x)
3-element Vector{Tracer}:
Tracer(1,)
Tracer(2,)
Tracer(3,)
3-element Vector{ConnectivityTracer}:
ConnectivityTracer(1,)
ConnectivityTracer(2,)
ConnectivityTracer(3,)
julia> yt = f(xt)
3-element Vector{Tracer}:
Tracer(1,)
Tracer(1, 2)
Tracer(3,)
3-element Vector{ConnectivityTracer}:
ConnectivityTracer(1,)
ConnectivityTracer(1, 2)
ConnectivityTracer(3,)
```
"""
trace_input(x) = trace_input(x, 1)
trace_input(::Number, i) = tracer(i)
trace_input(::Number, i) = connectivitytracer(i)
function trace_input(x::AbstractArray, i)
indices = (i - 1) .+ reshape(1:length(x), size(x))
return tracer.(indices)
return connectivitytracer.(indices)
end

## Construct connectivity matrix
Expand Down Expand Up @@ -65,28 +65,28 @@ where `C[i, j]` is true if the compute graph connects the `i`-th entry in `y` to
"""
function connectivity(f!, y, x)
xt = trace_input(x)
yt = similar(y, Tracer)
yt = similar(y, ConnectivityTracer)
f!(yt, xt)
return _connectivity(xt, yt)
end

_connectivity(xt::Tracer, yt::Number) = _connectivity([xt], [yt])
_connectivity(xt::Tracer, yt::AbstractArray{Number}) = _connectivity([xt], yt)
_connectivity(xt::AbstractArray{Tracer}, yt::Number) = _connectivity(xt, [yt])
function _connectivity(xt::AbstractArray{Tracer}, yt::AbstractArray{<:Number})
_connectivity(xt::ConnectivityTracer, yt::Number) = _connectivity([xt], [yt])
_connectivity(xt::ConnectivityTracer, yt::AbstractArray{Number}) = _connectivity([xt], yt)
_connectivity(xt::AbstractArray{ConnectivityTracer}, yt::Number) = _connectivity(xt, [yt])
function _connectivity(xt::AbstractArray{ConnectivityTracer}, yt::AbstractArray{<:Number})
return connectivity_sparsematrixcsc(xt, yt)
end

function connectivity_sparsematrixcsc(
xt::AbstractArray{Tracer}, yt::AbstractArray{<:Number}
xt::AbstractArray{ConnectivityTracer}, yt::AbstractArray{<:Number}
)
# Construct connectivity matrix of size (ouput_dim, input_dim)
n, m = length(xt), length(yt)
I = UInt64[]
J = UInt64[]
V = Bool[]
for (i, y) in enumerate(yt)
if y isa Tracer
if y isa ConnectivityTracer
for j in inputs(y)
push!(I, i)
push!(J, j)
Expand Down
26 changes: 13 additions & 13 deletions src/overload_connectivity.jl
Original file line number Diff line number Diff line change
@@ -1,32 +1,32 @@
for fn in union(ops_1_to_1_s, ops_1_to_1_f, ops_1_to_1_z)
@eval Base.$fn(t::Tracer) = t
@eval Base.$fn(t::ConnectivityTracer) = t
end

for fn in ops_1_to_1_const
@eval Base.$fn(::Tracer) = EMPTY_TRACER
@eval Base.$fn(::ConnectivityTracer) = EMPTY_CONNECTIVITY_TRACER
end

for fn in ops_1_to_2
@eval Base.$fn(t::Tracer) = (t, t)
@eval Base.$fn(t::ConnectivityTracer) = (t, t)
end

for fn in ops_2_to_1
@eval Base.$fn(a::Tracer, b::Tracer) = uniontracer(a, b)
@eval Base.$fn(t::Tracer, ::Number) = t
@eval Base.$fn(::Number, t::Tracer) = t
@eval Base.$fn(a::ConnectivityTracer, b::ConnectivityTracer) = uniontracer(a, b)
@eval Base.$fn(t::ConnectivityTracer, ::Number) = t
@eval Base.$fn(::Number, t::ConnectivityTracer) = t
end

# Extra types required for exponent
Base.:^(a::Tracer, b::Tracer) = uniontracer(a, b)
Base.:^(a::ConnectivityTracer, b::ConnectivityTracer) = uniontracer(a, b)
for T in (:Real, :Integer, :Rational)
@eval Base.:^(t::Tracer, ::$T) = t
@eval Base.:^(::$T, t::Tracer) = t
@eval Base.:^(t::ConnectivityTracer, ::$T) = t
@eval Base.:^(::$T, t::ConnectivityTracer) = t
end
Base.:^(t::Tracer, ::Irrational{:ℯ}) = t
Base.:^(::Irrational{:ℯ}, t::Tracer) = t
Base.:^(t::ConnectivityTracer, ::Irrational{:ℯ}) = t
Base.:^(::Irrational{:ℯ}, t::ConnectivityTracer) = t

## Rounding
Base.round(t::Tracer, ::RoundingMode; kwargs...) = t
Base.round(t::ConnectivityTracer, ::RoundingMode; kwargs...) = t

## Random numbers
rand(::AbstractRNG, ::SamplerType{Tracer}) = EMPTY_TRACER
rand(::AbstractRNG, ::SamplerType{ConnectivityTracer}) = EMPTY_CONNECTIVITY_TRACER
42 changes: 23 additions & 19 deletions src/tracer_connectivity.jl
Original file line number Diff line number Diff line change
@@ -1,46 +1,50 @@
"""
Tracer(indexset) <: Number
ConnectivityTracer(indexset) <: Number
Number type keeping track of input indices of previous computations.
See also the convenience constructor [`tracer`](@ref).
For a higher-level interface, refer to [`connectivity`](@ref).
"""
struct Tracer <: Number
struct ConnectivityTracer <: AbstractTracer
inputs::BitSet # indices of connected, enumerated inputs
end

const EMPTY_TRACER = Tracer(BitSet())
const EMPTY_CONNECTIVITY_TRACER = ConnectivityTracer(BitSet())
empty(::Type{ConnectivityTracer}) = EMPTY_CONNECTIVITY_TRACER
empty(::ConnectivityTracer) = EMPTY_CONNECTIVITY_TRACER

# We have to be careful when defining constructors:
# Generic code expecting "regular" numbers `x` will sometimes convert them
# by calling `T(x)` (instead of `convert(T, x)`), where `T` can be `Tracer`.
# by calling `T(x)` (instead of `convert(T, x)`), where `T` can be `ConnectivityTracer`.
# When this happens, we create a new empty tracer with no input connectivity.
Tracer(::Number) = EMPTY_TRACER
Tracer(t::Tracer) = t
ConnectivityTracer(::Number) = EMPTY_CONNECTIVITY_TRACER
ConnectivityTracer(t::ConnectivityTracer) = t

uniontracer(a::Tracer, b::Tracer) = Tracer(union(a.inputs, b.inputs))
function uniontracer(a::ConnectivityTracer, b::ConnectivityTracer)
return ConnectivityTracer(union(a.inputs, b.inputs))
end

"""
tracer(index)
tracer(indices)
connectivitytracer(index)
connectivitytracer(indices)
Convenience constructor for [`Tracer`](@ref) from input indices.
Convenience constructor for [`ConnectivityTracer`](@ref) from input indices.
"""
tracer(index::Integer) = Tracer(BitSet(index))
tracer(inds::NTuple{N,<:Integer}) where {N} = Tracer(BitSet(inds))
tracer(inds...) = tracer(inds)
connectivitytracer(index::Integer) = ConnectivityTracer(BitSet(index))
connectivitytracer(inds::NTuple{N,<:Integer}) where {N} = ConnectivityTracer(BitSet(inds))
connectivitytracer(inds...) = connectivitytracer(inds)

# Utilities for accessing input indices
"""
inputs(tracer)
Return raw `UInt64` input indices of a [`Tracer`](@ref) or [`JacobianTracer`](@ref)
Return raw `UInt64` input indices of a [`ConnectivityTracer`](@ref) or [`JacobianTracer`](@ref)
## Example
```jldoctest
julia> t = tracer(1, 2, 4)
Tracer(1, 2, 4)
julia> t = connectivitytracer(1, 2, 4)
ConnectivityTracer(1, 2, 4)
julia> inputs(t)
3-element Vector{Int64}:
Expand All @@ -49,8 +53,8 @@ julia> inputs(t)
4
```
"""
inputs(t::Tracer) = collect(t.inputs)
inputs(t::ConnectivityTracer) = collect(t.inputs)

function Base.show(io::IO, t::Tracer)
return Base.show_delim_array(io, inputs(t), "Tracer(", ',', ')', true)
function Base.show(io::IO, t::ConnectivityTracer)
return Base.show_delim_array(io, inputs(t), "ConnectivityTracer(", ',', ')', true)
end

0 comments on commit 91bddf9

Please sign in to comment.