Skip to content

Commit

Permalink
Improve docs, export trace_input
Browse files Browse the repository at this point in the history
  • Loading branch information
adrhill committed Apr 1, 2024
1 parent 63560b0 commit b2c3920
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 13 deletions.
14 changes: 8 additions & 6 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,20 @@ Documentation for [SparseConnectivityTracer](https://github.com/adrhill/SparseCo
```

## API reference
SparseConnectivityTracer works by pushing a `Number` type called [`Tracer`](@ref) through generic functions:
### Interface
```@docs
Tracer
tracer
connectivity
```

The resulting connectivity matrix can be extracted using [`connectivity`](@ref):
### Internals
SparseConnectivityTracer works by pushing a `Number` type called [`Tracer`](@ref) through generic functions:
```@docs
connectivity
Tracer
tracer
trace_input
```

or manually from individual [`Tracer`](@ref) outputs:
The following utilities can be used to extract input indices from [`Tracer`](@ref)s:
```@docs
inputs
sortedinputs
Expand Down
4 changes: 3 additions & 1 deletion src/SparseConnectivityTracer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@ include("conversion.jl")
include("operators.jl")
include("connectivity.jl")

export Tracer, tracer, inputs, sortedinputs
export Tracer, tracer
export trace_input
export inputs, sortedinputs
export connectivity

end # module
38 changes: 38 additions & 0 deletions src/connectivity.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,29 @@
## Enumerate inputs

"""
trace_input(x)
Enumerates input indices and constructs [`Tracer`](@ref)s.
## Example
```julia-repl
julia> x = rand(3);
julia> f(x) = [x[1]^2, 2 * x[1] * x[2]^2, sin(x[3])];
julia> xt = trace_input(x)
3-element Vector{Tracer}:
Tracer(1,)
Tracer(2,)
Tracer(3,)
julia> yt = f(xt)
3-element Vector{Tracer}:
Tracer(1,)
Tracer(1, 2)
Tracer(3,)
```
"""
trace_input(x) = trace_input(x, 1)
trace_input(::Number, i) = tracer(i)
function trace_input(x::AbstractArray, i)
Expand All @@ -12,6 +37,19 @@ end
Enumerates inputs `x` and primal outputs `y=f(x)` and returns sparse connectivity matrix `C` of size `(m, n)`
where `C[i, j]` is true if the compute graph connects the `i`-th entry in `y` to the `j`-th entry in `x`.
## Example
```julia-repl
julia> x = rand(3);
julia> f(x) = [x[1]^2, 2 * x[1] * x[2]^2, sin(x[3])];
julia> connectivity(f, x)
3×3 SparseArrays.SparseMatrixCSC{Bool, UInt64} with 4 stored entries:
1 ⋅ ⋅
1 1 ⋅
⋅ ⋅ 1
```
"""
function connectivity(f::Function, x)
xt = trace_input(x)
Expand Down
67 changes: 61 additions & 6 deletions src/tracer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,33 +4,56 @@
Number type keeping track of input indices of previous computations.
See also the convenience constructor [`tracer`](@ref).
For a higher-level interface, refer to [`connectivity`](@ref).
## Examples
By enumerating inputs with tracers, we can keep track of input connectivities:
```julia-repl
julia> xt = [tracer(1), tracer(2), tracer(3)]
3-element Vector{Tracer}:
Tracer(1,)
Tracer(2,)
Tracer(3,)
julia> f(x) = [x[1]^2, 2 * x[1] * x[2]^2, sin(x[3])];
julia> yt = f(xt)
3-element Vector{Tracer}:
Tracer(1,)
Tracer(1, 2)
Tracer(3,)
```
This works via operator-overloading, which either keep input connectivities constant,
compute unions or set connectivities to zero:
```julia-repl
julia> x = tracer(1, 2, 3)
Tracer(1, 2, 3)
julia> sin(x)
julia> sin(x) # Most operators don't modify input connectivities.
Tracer(1, 2, 3)
julia> 2 * x^3
Tracer(1, 2, 3)
julia> 0 * x # Note: Tracer is strictly operator overloading...
Tracer(1, 2, 3)
julia> zero(x) # ...this can be overloaded
julia> zero(x) # Tracer is strictly operator overloading...
Tracer()
julia> 0 * x # ...and doesn't look at input values.
Tracer(1, 2, 3)
julia> y = tracer(3, 5)
Tracer(3, 5)
julia> x + y
julia> x + y # Operations on two Tracers construct union sets
Tracer(1, 2, 3, 5)
julia> x ^ y
Tracer(1, 2, 3, 5)
```
[`Tracer`](@ref) also supports random number generation and pre-allocations:
```
julia> M = rand(Tracer, 3, 2)
3×2 Matrix{Tracer}:
Tracer() Tracer()
Expand Down Expand Up @@ -80,6 +103,19 @@ tracer(inds...) = tracer(inds)
inputs(tracer)
Return raw `UInt64` input indices of a [`Tracer`](@ref).
See also [`sortedinputs`](@ref).
## Example
```julia-repl
julia> t = tracer(1, 2, 4)
Tracer(1, 2, 4)
julia> inputs(t)
3-element Vector{UInt64}:
0x0000000000000004
0x0000000000000002
0x0000000000000001
```
"""
inputs(t::Tracer) = collect(keys(t.inputs.dict))

Expand All @@ -88,6 +124,25 @@ inputs(t::Tracer) = collect(keys(t.inputs.dict))
sortedinputs([T=Int], tracer)
Return sorted input indices of a [`Tracer`](@ref).
See also [`inputs`](@ref).
## Example
```julia
julia> t = tracer(1, 2, 4)
Tracer(1, 2, 4)
julia> sortedinputs(t)
3-element Vector{Int64}:
1
2
4
julia> sortedinputs(UInt8, t)
3-element Vector{UInt8}:
0x01
0x02
0x04
```
"""
sortedinputs(t::Tracer) = sortedinputs(Int, t)
sortedinputs(T::Type, t::Tracer) = convert.(T, sort!(inputs(t)))
Expand Down

0 comments on commit b2c3920

Please sign in to comment.