Skip to content

Commit

Permalink
Remove ConnectivityTracer and legacy API (#140)
Browse files Browse the repository at this point in the history
* Remove `ConnectivityTracer`

* Remove legacy interface

* Add CHANGELOG: this is our first actually breaking release in a while
  • Loading branch information
adrhill authored Jun 27, 2024
1 parent f662af1 commit ac94586
Show file tree
Hide file tree
Showing 25 changed files with 132 additions and 883 deletions.
44 changes: 44 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# SparseConnectivityTracer.jl

## Version `v0.6.0`
* ![BREAKING][badge-breaking] Remove `ConnectivityTracer` ([#140][pr-140])
* ![BREAKING][badge-breaking] Remove legacy interface ([#140][pr-140])
* instead of `jacobian_pattern(f, x)`, use `jacobian_sparsity(f, x, TracerSparsityDetector())`
* instead of `hessian_pattern(f, x)`, use `hessian_sparsity(f, x, TracerSparsityDetector())`
* instead of `local_jacobian_pattern(f, x)`, use `jacobian_sparsity(f, x, TracerLocalSparsityDetector())`
* instead of `local_hessian_pattern(f, x)`, use `hessian_sparsity(f, x, TracerLocalSparsityDetector())`
* ![Bugfix][badge-bugfix] Remove overloads on `similar` to reduce amount of invalidations ([#132][pr-132])
* ![Enhancement][badge-enhancement] Add array overloads ([#131][pr-131])
* ![Enhancement][badge-enhancement] Generalize sparsity pattern representations ([#139][pr-139], [#119][pr-119])
* ![Enhancement][badge-enhancement] Reduce allocations of new tracers ([#128][pr-128])
* ![Enhancement][badge-enhancement] Reduce compile times ([#119][pr-119])

[pr-140]: https://github.com/adrhill/SparseConnectivityTracer.jl/pull/140
[pr-139]: https://github.com/adrhill/SparseConnectivityTracer.jl/pull/139
[pr-132]: https://github.com/adrhill/SparseConnectivityTracer.jl/pull/132
[pr-131]: https://github.com/adrhill/SparseConnectivityTracer.jl/pull/131
[pr-128]: https://github.com/adrhill/SparseConnectivityTracer.jl/pull/128
[pr-126]: https://github.com/adrhill/SparseConnectivityTracer.jl/pull/126
[pr-119]: https://github.com/adrhill/SparseConnectivityTracer.jl/pull/119

<!--
# Badges
![BREAKING][badge-breaking]
![Deprecation][badge-deprecation]
![Feature][badge-feature]
![Enhancement][badge-enhancement]
![Bugfix][badge-bugfix]
![Experimental][badge-experimental]
![Maintenance][badge-maintenance]
![Documentation][badge-docs]
-->

[badge-breaking]: https://img.shields.io/badge/BREAKING-red.svg
[badge-deprecation]: https://img.shields.io/badge/deprecation-orange.svg
[badge-feature]: https://img.shields.io/badge/feature-green.svg
[badge-enhancement]: https://img.shields.io/badge/enhancement-blue.svg
[badge-bugfix]: https://img.shields.io/badge/bugfix-purple.svg
[badge-security]: https://img.shields.io/badge/security-black.svg
[badge-experimental]: https://img.shields.io/badge/experimental-lightgrey.svg
[badge-maintenance]: https://img.shields.io/badge/maintenance-gray.svg
[badge-docs]: https://img.shields.io/badge/docs-orange.svg
43 changes: 27 additions & 16 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,6 @@

Fast Jacobian and Hessian sparsity detection via operator-overloading.

> [!WARNING]
> This package is in early development. Expect frequent breaking changes and refer to the stable documentation.
## Installation
To install this package, open the Julia REPL and run

Expand All @@ -21,17 +18,19 @@ julia> ]add SparseConnectivityTracer
## Examples
### Jacobian

For functions `y = f(x)` and `f!(y, x)`, the sparsity pattern of the Jacobian of $f$ can be obtained
by computing a single forward-pass through `f`:
For functions `y = f(x)` and `f!(y, x)`, the sparsity pattern of the Jacobian can be obtained
by computing a single forward-pass through the function:

```julia-repl
julia> using SparseConnectivityTracer
julia> detector = TracerSparsityDetector();
julia> x = rand(3);
julia> f(x) = [x[1]^2, 2 * x[1] * x[2]^2, sin(x[3])];
julia> jacobian_pattern(f, x)
julia> jacobian_sparsity(f, x, detector)
3×3 SparseArrays.SparseMatrixCSC{Bool, Int64} with 4 stored entries:
1 ⋅ ⋅
1 1 ⋅
Expand All @@ -43,11 +42,13 @@ As a larger example, let's compute the sparsity pattern from a convolutional lay
```julia-repl
julia> using SparseConnectivityTracer, Flux
julia> detector = TracerSparsityDetector();
julia> x = rand(28, 28, 3, 1);
julia> layer = Conv((3, 3), 3 => 2);
julia> jacobian_pattern(layer, x)
julia> jacobian_sparsity(layer, x, detector)
1352×2352 SparseArrays.SparseMatrixCSC{Bool, Int64} with 36504 stored entries:
⎡⠙⢿⣦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠻⣷⣤⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⠻⣷⣄⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⎤
⎢⠀⠀⠙⢿⣦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⠙⢿⣦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⠻⣷⣤⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⎥
Expand All @@ -64,7 +65,7 @@ julia> jacobian_pattern(layer, x)
⎣⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠙⢿⣦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⠛⢿⣦⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⠻⣷⣄⎦
```

The type of index set `S` that is internally used to keep track of connectivity can be specified via `jacobian_pattern(f, x, S)`, defaulting to `BitSet`.
The type of index set `S` that is internally used to keep track of connectivity can be specified via `jacobian_sparsity(f, x, S)`, defaulting to `BitSet`.
For high-dimensional functions, `Set{Int64}` can be more efficient .

### Hessian
Expand All @@ -77,7 +78,7 @@ julia> x = rand(5);
julia> f(x) = x[1] + x[2]*x[3] + 1/x[4] + 1*x[5];
julia> hessian_pattern(f, x)
julia> hessian_sparsity(f, x, detector)
5×5 SparseArrays.SparseMatrixCSC{Bool, Int64} with 3 stored entries:
⋅ ⋅ ⋅ ⋅ ⋅
⋅ ⋅ 1 ⋅ ⋅
Expand All @@ -87,7 +88,7 @@ julia> hessian_pattern(f, x)
julia> g(x) = f(x) + x[2]^x[5];
julia> hessian_pattern(g, x)
julia> hessian_sparsity(g, x, detector)
5×5 SparseArrays.SparseMatrixCSC{Bool, Int64} with 7 stored entries:
⋅ ⋅ ⋅ ⋅ ⋅
⋅ 1 1 ⋅ 1
Expand All @@ -100,30 +101,40 @@ For more detailled examples, take a look at the [documentation](https://adrianhi

### Local tracing

The functions `jacobian_pattern`, `hessian_pattern` and `connectivity_pattern` return conservative sparsity patterns over the entire input domain of `x`.
They are not compatible with functions that require information about the primal values of a computation (e.g. `iszero`, `>`, `==`).
`TracerSparsityDetector` returns conservative sparsity patterns over the entire input domain of `x`.
It is not compatible with functions that require information about the primal values of a computation (e.g. `iszero`, `>`, `==`).

To compute a less conservative sparsity pattern at an input point `x`, use `local_jacobian_pattern`, `local_hessian_pattern` and `local_connectivity_pattern` instead.
Note that these patterns depend on the input `x`:
To compute a less conservative sparsity pattern at an input point `x`, use `TracerLocalSparsityDetector` instead.
Note that patterns computed with `TracerLocalSparsityDetector` depend on the input `x`:

```julia-repl
julia> using SparseConnectivityTracer
julia> detector = TracerLocalSparsityDetector();
julia> f(x) = ifelse(x[2] < x[3], x[1] ^ x[2], x[3] * x[4]);
julia> local_hessian_pattern(f, [1 2 3 4])
julia> hessian_sparsity(f, [1 2 3 4], detector)
4×4 SparseArrays.SparseMatrixCSC{Bool, Int64} with 4 stored entries:
1 1 ⋅ ⋅
1 1 ⋅ ⋅
⋅ ⋅ ⋅ ⋅
⋅ ⋅ ⋅ ⋅
julia> local_hessian_pattern(f, [1 3 2 4])
julia> hessian_sparsity(f, [1 3 2 4], detector)
4×4 SparseArrays.SparseMatrixCSC{Bool, Int64} with 2 stored entries:
⋅ ⋅ ⋅ ⋅
⋅ ⋅ ⋅ ⋅
⋅ ⋅ ⋅ 1
⋅ ⋅ 1 ⋅
```

## ADTypes.jl compatibility
SparseConnectivityTracer uses [ADTypes.jl](https://github.com/SciML/ADTypes.jl)'s interface for [sparsity detection](https://sciml.github.io/ADTypes.jl/stable/#Sparsity-detector),
making it compatible with [DifferentiationInterface.jl](https://github.com/gdalle/DifferentiationInterface.jl)'s [sparse automatic differentiation](https://gdalle.github.io/DifferentiationInterface.jl/DifferentiationInterface/stable/tutorial2/) functionality.

In fact, the functions `jacobian_sparsity` and `hessian_sparsity` are re-exported from ADTypes.

## Related packages
* [SparseDiffTools.jl](https://github.com/JuliaDiff/SparseDiffTools.jl): automatic sparsity detection via Symbolics.jl and Cassette.jl
* [SparsityTracing.jl](https://github.com/PALEOtoolkit/SparsityTracing.jl): automatic Jacobian sparsity detection using an algorithm based on SparsLinC by Bischof et al. (1996)
45 changes: 5 additions & 40 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,55 +10,28 @@ CollapsedDocStrings = true

## ADTypes Interface

For package developers, we recommend using the [ADTypes.jl](https://github.com/SciML/ADTypes.jl) interface.
SparseConnectivityTracer uses [ADTypes.jl](https://github.com/SciML/ADTypes.jl)'s interface for [sparsity detection](https://sciml.github.io/ADTypes.jl/stable/#Sparsity-detector).
In fact, the functions `jacobian_sparsity` and `hessian_sparsity` are re-exported from ADTypes.

To compute global sparsity patterns of `f(x)` over the entire input domain `x`, use
To compute **global** sparsity patterns of `f(x)` over the entire input domain `x`, use
```@docs
TracerSparsityDetector
```

To compute local sparsity patterns of `f(x)` at a specific input `x`, use
To compute **local** sparsity patterns of `f(x)` at a specific input `x`, use
```@docs
TracerLocalSparsityDetector
```

## Legacy Interface

### Global sparsity

The following functions can be used to compute global sparsity patterns of `f(x)` over the entire input domain `x`.

```@docs
connectivity_pattern
jacobian_pattern
hessian_pattern
```

[`TracerSparsityDetector`](@ref) is the ADTypes equivalent of these functions.

### Local sparsity

The following functions can be used to compute local sparsity patterns of `f(x)` at a specific input `x`.
Note that these patterns are sparser than global patterns but need to be recomputed when `x` changes.

```@docs
local_connectivity_pattern
local_jacobian_pattern
local_hessian_pattern
```

[`TracerLocalSparsityDetector`](@ref) is the ADTypes equivalent of these functions.

## Internals

!!! warning
Internals may change without warning in a future release of SparseConnectivityTracer.

SparseConnectivityTracer works by pushing `Real` number types called tracers through generic functions.
Currently, three tracer types are provided:
Currently, two tracer types are provided:

```@docs
SparseConnectivityTracer.ConnectivityTracer
SparseConnectivityTracer.GradientTracer
SparseConnectivityTracer.HessianTracer
```
Expand All @@ -69,11 +42,3 @@ which keeps track of the primal computation and allows tracing through compariso
```@docs
SparseConnectivityTracer.Dual
```

We also define alternative pseudo-set types that can deliver faster `union`:

```@docs
SparseConnectivityTracer.DuplicateVector
SparseConnectivityTracer.RecursiveSet
SparseConnectivityTracer.SortedVector
```
2 changes: 0 additions & 2 deletions ext/SparseConnectivityTracerNNlibExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ ops_1_to_1_s = (

for op in ops_1_to_1_s
T = typeof(op)
@eval SCT.is_infl_zero_global(::$T) = false
@eval SCT.is_der1_zero_global(::$T) = false
@eval SCT.is_der2_zero_global(::$T) = false
end
Expand Down Expand Up @@ -69,7 +68,6 @@ ops_1_to_1_f = (

for op in ops_1_to_1_f
T = typeof(op)
@eval SCT.is_infl_zero_global(::$T) = false
@eval SCT.is_der1_zero_global(::$T) = false
@eval SCT.is_der2_zero_global(::$T) = true
end
Expand Down
1 change: 0 additions & 1 deletion ext/SparseConnectivityTracerSpecialFunctionsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ ops_1_to_1_s = (

for op in ops_1_to_1_s
T = typeof(op)
@eval SCT.is_infl_zero_global(::$T) = false
@eval SCT.is_der1_zero_global(::$T) = false
@eval SCT.is_der2_zero_global(::$T) = false
end
Expand Down
10 changes: 3 additions & 7 deletions src/SparseConnectivityTracer.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module SparseConnectivityTracer

using ADTypes: ADTypes
using ADTypes: ADTypes, jacobian_sparsity, hessian_sparsity
using Compat: Returns
using SparseArrays: SparseArrays
using SparseArrays: sparse
Expand All @@ -25,7 +25,6 @@ include("exceptions.jl")
include("operators.jl")

include("overloads/conversion.jl")
include("overloads/connectivity_tracer.jl")
include("overloads/gradient_tracer.jl")
include("overloads/hessian_tracer.jl")
include("overloads/ifelse_global.jl")
Expand All @@ -36,13 +35,10 @@ include("overloads/arrays.jl")
include("interface.jl")
include("adtypes.jl")

export connectivity_pattern, local_connectivity_pattern
export jacobian_pattern, local_jacobian_pattern
export hessian_pattern, local_hessian_pattern

# ADTypes interface
export TracerSparsityDetector
export TracerLocalSparsityDetector
# Reexport ADTypes interface
export jacobian_sparsity, hessian_sparsity

function __init__()
@static if !isdefined(Base, :get_extension)
Expand Down
20 changes: 10 additions & 10 deletions src/adtypes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ julia> using ADTypes, SparseConnectivityTracer
julia> f(x) = x[1] + x[2]*x[3] + 1/x[4];
julia> ADTypes.hessian_sparsity(f, rand(4), TracerSparsityDetector())
julia> hessian_sparsity(f, rand(4), TracerSparsityDetector())
4×4 SparseArrays.SparseMatrixCSC{Bool, Int64} with 3 stored entries:
⋅ ⋅ ⋅ ⋅
⋅ ⋅ 1 ⋅
Expand All @@ -46,15 +46,15 @@ function TracerSparsityDetector(;
end

function ADTypes.jacobian_sparsity(f, x, ::TracerSparsityDetector{TG,TH}) where {TG,TH}
return jacobian_pattern(f, x, TG)
return _jacobian_sparsity(f, x, TG)
end

function ADTypes.jacobian_sparsity(f!, y, x, ::TracerSparsityDetector{TG,TH}) where {TG,TH}
return jacobian_pattern(f!, y, x, TG)
return _jacobian_sparsity(f!, y, x, TG)
end

function ADTypes.hessian_sparsity(f, x, ::TracerSparsityDetector{TG,TH}) where {TG,TH}
return hessian_pattern(f, x, TH)
return _hessian_sparsity(f, x, TH)
end

"""
Expand All @@ -72,13 +72,13 @@ julia> using ADTypes, SparseConnectivityTracer
julia> f(x) = x[1] > x[2] ? x[1:3] : x[2:4];
julia> ADTypes.jacobian_sparsity(f, [1.0, 2.0, 3.0, 4.0], TracerLocalSparsityDetector())
julia> jacobian_sparsity(f, [1.0, 2.0, 3.0, 4.0], TracerLocalSparsityDetector())
3×4 SparseArrays.SparseMatrixCSC{Bool, Int64} with 3 stored entries:
⋅ 1 ⋅ ⋅
⋅ ⋅ 1 ⋅
⋅ ⋅ ⋅ 1
julia> ADTypes.jacobian_sparsity(f, [2.0, 1.0, 3.0, 4.0], TracerLocalSparsityDetector())
julia> jacobian_sparsity(f, [2.0, 1.0, 3.0, 4.0], TracerLocalSparsityDetector())
3×4 SparseArrays.SparseMatrixCSC{Bool, Int64} with 3 stored entries:
1 ⋅ ⋅ ⋅
⋅ 1 ⋅ ⋅
Expand All @@ -90,7 +90,7 @@ julia> using ADTypes, SparseConnectivityTracer
julia> f(x) = x[1] + max(x[2], x[3]) * x[3] + 1/x[4];
julia> ADTypes.hessian_sparsity(f, [1.0, 2.0, 3.0, 4.0], TracerLocalSparsityDetector())
julia> hessian_sparsity(f, [1.0, 2.0, 3.0, 4.0], TracerLocalSparsityDetector())
4×4 SparseArrays.SparseMatrixCSC{Bool, Int64} with 2 stored entries:
⋅ ⋅ ⋅ ⋅
⋅ ⋅ ⋅ ⋅
Expand All @@ -113,15 +113,15 @@ function TracerLocalSparsityDetector(;
end

function ADTypes.jacobian_sparsity(f, x, ::TracerLocalSparsityDetector{TG,TH}) where {TG,TH}
return local_jacobian_pattern(f, x, TG)
return _local_jacobian_sparsity(f, x, TG)
end

function ADTypes.jacobian_sparsity(
f!, y, x, ::TracerLocalSparsityDetector{TG,TH}
) where {TG,TH}
return local_jacobian_pattern(f!, y, x, TG)
return _local_jacobian_sparsity(f!, y, x, TG)
end

function ADTypes.hessian_sparsity(f, x, ::TracerLocalSparsityDetector{TG,TH}) where {TG,TH}
return local_hessian_pattern(f, x, TH)
return _local_hessian_sparsity(f, x, TH)
end
8 changes: 1 addition & 7 deletions src/exceptions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,7 @@ function Base.showerror(io::IO, e::MissingPrimalError)
println(io, "Function ", e.fn, " requires primal value(s).")
print(
io,
"A dual-number tracer for local sparsity detection can be used via `",
str_local_pattern_fn(e.tracer),
"`.",
"A dual-number tracer for local sparsity detection can be used via `TracerLocalSparsityDetector`.",
)
return nothing
end

str_local_pattern_fn(::ConnectivityTracer) = "local_connectivity_pattern"
str_local_pattern_fn(::GradientTracer) = "local_jacobian_pattern"
str_local_pattern_fn(::HessianTracer) = "local_hessian_pattern"
Loading

0 comments on commit ac94586

Please sign in to comment.