diff --git a/docs/src/index.md b/docs/src/index.md index b9d2be31..d9a71ae3 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -10,18 +10,20 @@ Documentation for [SparseConnectivityTracer](https://github.com/adrhill/SparseCo ``` ## API reference -SparseConnectivityTracer works by pushing a `Number` type called [`Tracer`](@ref) through generic functions: +### Interface ```@docs -Tracer -tracer +connectivity ``` -The resulting connectivity matrix can be extracted using [`connectivity`](@ref): +### Internals +SparseConnectivityTracer works by pushing a `Number` type called [`Tracer`](@ref) through generic functions: ```@docs -connectivity +Tracer +tracer +trace_input ``` -or manually from individual [`Tracer`](@ref) outputs: +The following utilities can be used to extract input indices from [`Tracer`](@ref)s: ```@docs inputs sortedinputs diff --git a/src/SparseConnectivityTracer.jl b/src/SparseConnectivityTracer.jl index 43da4fa0..b9de0f2e 100644 --- a/src/SparseConnectivityTracer.jl +++ b/src/SparseConnectivityTracer.jl @@ -7,7 +7,9 @@ include("conversion.jl") include("operators.jl") include("connectivity.jl") -export Tracer, tracer, inputs, sortedinputs +export Tracer, tracer +export trace_input +export inputs, sortedinputs export connectivity end # module diff --git a/src/connectivity.jl b/src/connectivity.jl index b5095c2c..5ffa6273 100644 --- a/src/connectivity.jl +++ b/src/connectivity.jl @@ -1,4 +1,29 @@ ## Enumerate inputs + +""" + trace_input(x) + +Enumerates input indices and constructs [`Tracer`](@ref)s. + +## Example +```julia-repl +julia> x = rand(3); + +julia> f(x) = [x[1]^2, 2 * x[1] * x[2]^2, sin(x[3])]; + +julia> xt = trace_input(x) +3-element Vector{Tracer}: + Tracer(1,) + Tracer(2,) + Tracer(3,) + +julia> yt = f(xt) +3-element Vector{Tracer}: + Tracer(1,) + Tracer(1, 2) + Tracer(3,) +``` +""" trace_input(x) = trace_input(x, 1) trace_input(::Number, i) = tracer(i) function trace_input(x::AbstractArray, i) @@ -12,6 +37,19 @@ end Enumerates inputs `x` and primal outputs `y=f(x)` and returns sparse connectivity matrix `C` of size `(m, n)` where `C[i, j]` is true if the compute graph connects the `i`-th entry in `y` to the `j`-th entry in `x`. + +## Example +```julia-repl +julia> x = rand(3); + +julia> f(x) = [x[1]^2, 2 * x[1] * x[2]^2, sin(x[3])]; + +julia> connectivity(f, x) +3×3 SparseArrays.SparseMatrixCSC{Bool, UInt64} with 4 stored entries: + 1 ⋅ ⋅ + 1 1 ⋅ + ⋅ ⋅ 1 +``` """ function connectivity(f::Function, x) xt = trace_input(x) diff --git a/src/tracer.jl b/src/tracer.jl index 6340ec93..5bf97191 100644 --- a/src/tracer.jl +++ b/src/tracer.jl @@ -4,33 +4,56 @@ Number type keeping track of input indices of previous computations. See also the convenience constructor [`tracer`](@ref). +For a higher-level interface, refer to [`connectivity`](@ref). ## Examples +By enumerating inputs with tracers, we can keep track of input connectivities: +```julia-repl +julia> xt = [tracer(1), tracer(2), tracer(3)] +3-element Vector{Tracer}: + Tracer(1,) + Tracer(2,) + Tracer(3,) + +julia> f(x) = [x[1]^2, 2 * x[1] * x[2]^2, sin(x[3])]; + +julia> yt = f(xt) +3-element Vector{Tracer}: + Tracer(1,) + Tracer(1, 2) + Tracer(3,) +``` + +This works via operator-overloading, which either keep input connectivities constant, +compute unions or set connectivities to zero: ```julia-repl julia> x = tracer(1, 2, 3) Tracer(1, 2, 3) -julia> sin(x) +julia> sin(x) # Most operators don't modify input connectivities. Tracer(1, 2, 3) julia> 2 * x^3 Tracer(1, 2, 3) -julia> 0 * x # Note: Tracer is strictly operator overloading... -Tracer(1, 2, 3) - -julia> zero(x) # ...this can be overloaded +julia> zero(x) # Tracer is strictly operator overloading... Tracer() +julia> 0 * x # ...and doesn't look at input values. +Tracer(1, 2, 3) + julia> y = tracer(3, 5) Tracer(3, 5) -julia> x + y +julia> x + y # Operations on two Tracers construct union sets Tracer(1, 2, 3, 5) julia> x ^ y Tracer(1, 2, 3, 5) +``` +[`Tracer`](@ref) also supports random number generation and pre-allocations: +``` julia> M = rand(Tracer, 3, 2) 3×2 Matrix{Tracer}: Tracer() Tracer() @@ -80,6 +103,19 @@ tracer(inds...) = tracer(inds) inputs(tracer) Return raw `UInt64` input indices of a [`Tracer`](@ref). +See also [`sortedinputs`](@ref). + +## Example +```julia-repl +julia> t = tracer(1, 2, 4) +Tracer(1, 2, 4) + +julia> inputs(t) +3-element Vector{UInt64}: + 0x0000000000000004 + 0x0000000000000002 + 0x0000000000000001 +``` """ inputs(t::Tracer) = collect(keys(t.inputs.dict)) @@ -88,6 +124,25 @@ inputs(t::Tracer) = collect(keys(t.inputs.dict)) sortedinputs([T=Int], tracer) Return sorted input indices of a [`Tracer`](@ref). +See also [`inputs`](@ref). + +## Example +```julia +julia> t = tracer(1, 2, 4) +Tracer(1, 2, 4) + +julia> sortedinputs(t) +3-element Vector{Int64}: + 1 + 2 + 4 + +julia> sortedinputs(UInt8, t) +3-element Vector{UInt8}: + 0x01 + 0x02 + 0x04 +``` """ sortedinputs(t::Tracer) = sortedinputs(Int, t) sortedinputs(T::Type, t::Tracer) = convert.(T, sort!(inputs(t)))