Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support user-defined set types #31

Merged
merged 9 commits into from
May 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "SparseConnectivityTracer"
uuid = "9f842d2f-2579-4b1d-911e-f412cf18a3f5"
authors = ["Adrian Hill <[email protected]>"]
version = "0.2.1"
version = "0.3.0-DEV"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
11 changes: 7 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@

Fast Jacobian and Hessian sparsity detection via operator-overloading.

> [!WARNING]
> This package is in early development. Expect frequent breaking changes and refer to the stable documentation.

## Installation
To install this package, open the Julia REPL and run

Expand All @@ -28,7 +31,7 @@ julia> x = rand(3);

julia> f(x) = [x[1]^2, 2 * x[1] * x[2]^2, sin(x[3])];

julia> pattern(f, JacobianTracer, x)
julia> pattern(f, JacobianTracer{BitSet}, x)
3×3 SparseArrays.SparseMatrixCSC{Bool, UInt64} with 4 stored entries:
1 ⋅ ⋅
1 1 ⋅
Expand All @@ -43,7 +46,7 @@ julia> x = rand(28, 28, 3, 1);

julia> layer = Conv((3, 3), 3 => 8);

julia> pattern(layer, JacobianTracer, x)
julia> pattern(layer, JacobianTracer{BitSet}, x)
5408×2352 SparseArrays.SparseMatrixCSC{Bool, UInt64} with 146016 stored entries:
⎡⠙⢦⡀⠀⠀⠘⢷⣄⠀⠀⠈⠻⣦⡀⠀⠀⠀⎤
⎢⠀⠀⠙⢷⣄⠀⠀⠙⠷⣄⠀⠀⠈⠻⣦⡀⠀⎥
Expand Down Expand Up @@ -76,7 +79,7 @@ julia> x = rand(5);

julia> f(x) = x[1] + x[2]*x[3] + 1/x[4] + 1*x[5];

julia> pattern(f, HessianTracer, x)
julia> pattern(f, HessianTracer{BitSet}, x)
5×5 SparseArrays.SparseMatrixCSC{Bool, UInt64} with 3 stored entries:
⋅ ⋅ ⋅ ⋅ ⋅
⋅ ⋅ 1 ⋅ ⋅
Expand All @@ -86,7 +89,7 @@ julia> pattern(f, HessianTracer, x)

julia> g(x) = f(x) + x[2]^x[5];

julia> pattern(g, HessianTracer, x)
julia> pattern(g, HessianTracer{BitSet}, x)
5×5 SparseArrays.SparseMatrixCSC{Bool, UInt64} with 7 stored entries:
⋅ ⋅ ⋅ ⋅ ⋅
⋅ 1 1 ⋅ 1
Expand Down
24 changes: 16 additions & 8 deletions src/adtypes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,24 @@ julia> ADTypes.hessian_sparsity(f, rand(4), TracerSparsityDetector())
⋅ ⋅ ⋅ 1
```
"""
struct TracerSparsityDetector <: ADTypes.AbstractSparsityDetector end

function ADTypes.jacobian_sparsity(f, x, ::TracerSparsityDetector)
return pattern(f, JacobianTracer, x)
struct TracerSparsityDetector{S<:AbstractIndexSet} <: ADTypes.AbstractSparsityDetector end
TracerSparsityDetector(::Type{S}) where {S<:AbstractIndexSet} = TracerSparsityDetector{S}()
TracerSparsityDetector() = TracerSparsityDetector(BitSet)

function ADTypes.jacobian_sparsity(
f, x, ::TracerSparsityDetector{S}
) where {S<:AbstractIndexSet}
return pattern(f, JacobianTracer{S}, x)
end

function ADTypes.jacobian_sparsity(f!, y, x, ::TracerSparsityDetector)
return pattern(f!, y, JacobianTracer, x)
function ADTypes.jacobian_sparsity(
f!, y, x, ::TracerSparsityDetector{S}
) where {S<:AbstractIndexSet}
return pattern(f!, y, JacobianTracer{S}, x)
end

function ADTypes.hessian_sparsity(f, x, ::TracerSparsityDetector)
return pattern(f, HessianTracer, x)
function ADTypes.hessian_sparsity(
f, x, ::TracerSparsityDetector{S}
) where {S<:AbstractIndexSet}
return pattern(f, HessianTracer{S}, x)
end
43 changes: 23 additions & 20 deletions src/conversion.jl
Original file line number Diff line number Diff line change
@@ -1,28 +1,31 @@
## Type conversions
for T in (:JacobianTracer, :ConnectivityTracer, :HessianTracer)
@eval Base.promote_rule(::Type{$T}, ::Type{N}) where {N<:Number} = $T
@eval Base.promote_rule(::Type{N}, ::Type{$T}) where {N<:Number} = $T
for TT in (:JacobianTracer, :ConnectivityTracer, :HessianTracer)
@eval Base.promote_rule(::Type{T}, ::Type{N}) where {T<:$TT,N<:Number} = T
@eval Base.promote_rule(::Type{N}, ::Type{T}) where {T<:$TT,N<:Number} = T

@eval Base.big(::Type{$T}) = $T
@eval Base.widen(::Type{$T}) = $T
@eval Base.widen(t::$T) = t
@eval Base.big(::Type{T}) where {T<:$TT} = T
@eval Base.widen(::Type{T}) where {T<:$TT} = T
@eval Base.widen(t::T) where {T<:$TT} = t

@eval Base.convert(::Type{$T}, x::Number) = empty($T)
@eval Base.convert(::Type{$T}, t::$T) = t
@eval Base.convert(::Type{<:Number}, t::$T) = t
@eval Base.convert(::Type{T}, x::Number) where {T<:$TT} = empty(T)
@eval Base.convert(::Type{T}, t::T) where {T<:$TT} = t
@eval Base.convert(::Type{<:Number}, t::T) where {T<:$TT} = t

## Constants
@eval Base.zero(::Type{$T}) = empty($T)
@eval Base.one(::Type{$T}) = empty($T)
@eval Base.typemin(::Type{$T}) = empty($T)
@eval Base.typemax(::Type{$T}) = empty($T)
@eval Base.zero(::Type{T}) where {T<:$TT} = empty(T)
@eval Base.one(::Type{T}) where {T<:$TT} = empty(T)
@eval Base.typemin(::Type{T}) where {T<:$TT} = empty(T)
@eval Base.typemax(::Type{T}) where {T<:$TT} = empty(T)

## Array constructors
@eval Base.similar(a::Array{$T,1}) = zeros($T, size(a, 1))
@eval Base.similar(a::Array{$T,2}) = zeros($T, size(a, 1), size(a, 2))
@eval Base.similar(a::Array{A,1}, ::Type{$T}) where {A} = zeros($T, size(a, 1))
@eval Base.similar(a::Array{A,2}, ::Type{$T}) where {A} = zeros($T, size(a, 1), size(a, 2))
@eval Base.similar(::Array{$T}, m::Int) = zeros($T, m)
@eval Base.similar(::Array, ::Type{$T}, dims::Dims{N}) where {N} = zeros($T, dims)
@eval Base.similar(::Array{$T}, dims::Dims{N}) where {N} = zeros($T, dims)
@eval Base.similar(a::Array{T,1}) where {T<:$TT} = zeros(T, size(a, 1))
@eval Base.similar(a::Array{T,2}) where {T<:$TT} = zeros(T, size(a, 1), size(a, 2))
@eval Base.similar(a::Array{A,1}, ::Type{T}) where {A,T<:$TT} = zeros(T, size(a, 1))
@eval Base.similar(a::Array{A,2}, ::Type{T}) where {A,T<:$TT} = zeros(T, size(a, 1), size(a, 2))
@eval Base.similar(::Array{T}, m::Int) where {T<:$TT} = zeros(T, m)
@eval Base.similar(::Array{T}, dims::Dims{N}) where {N,T<:$TT} = zeros(T, dims)

@eval Base.similar(
::Array, ::Type{$TT{S}}, dims::Dims{N}
) where {N,S<:AbstractIndexSet} = zeros($TT{S}, dims)
end
4 changes: 2 additions & 2 deletions src/overload_connectivity.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ for fn in union(ops_1_to_1_s, ops_1_to_1_f, ops_1_to_1_z)
end

for fn in ops_1_to_1_const
@eval Base.$fn(::ConnectivityTracer) = EMPTY_CONNECTIVITY_TRACER
@eval Base.$fn(::T) where {T<:ConnectivityTracer} = empty(T)
end

for fn in ops_1_to_2
Expand All @@ -28,4 +28,4 @@ Base.:^(::Irrational{:ℯ}, t::ConnectivityTracer) = t
Base.round(t::ConnectivityTracer, ::RoundingMode; kwargs...) = t

## Random numbers
rand(::AbstractRNG, ::SamplerType{ConnectivityTracer}) = EMPTY_CONNECTIVITY_TRACER
rand(::AbstractRNG, ::SamplerType{T}) where {T<:ConnectivityTracer} = empty(T)
20 changes: 10 additions & 10 deletions src/overload_hessian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ for fn in ops_1_to_1_f
end

for fn in union(ops_1_to_1_z, ops_1_to_1_const)
@eval Base.$fn(::HessianTracer) = EMPTY_HESSIAN_TRACER
@eval Base.$fn(::T) where {T<:HessianTracer} = empty(T)
end

## 2-to-1
Expand Down Expand Up @@ -86,31 +86,31 @@ end
for fn in ops_2_to_1_szz
@eval Base.$fn(t::HessianTracer, ::HessianTracer) = promote_order(t)
@eval Base.$fn(t::HessianTracer, ::Number) = promote_order(t)
@eval Base.$fn(::Number, t::HessianTracer) = EMPTY_HESSIAN_TRACER
@eval Base.$fn(::Number, t::T) where {T<:HessianTracer} = empty(T)
end

for fn in ops_2_to_1_zsz
@eval Base.$fn(::HessianTracer, t::HessianTracer) = promote_order(t)
@eval Base.$fn(::HessianTracer, ::Number) = EMPTY_HESSIAN_TRACER
@eval Base.$fn(::T, ::Number) where {T<:HessianTracer} = empty(T)
@eval Base.$fn(::Number, t::HessianTracer) = promote_order(t)
end

for fn in ops_2_to_1_fzz
@eval Base.$fn(t::HessianTracer, ::HessianTracer) = t
@eval Base.$fn(t::HessianTracer, ::Number) = t
@eval Base.$fn(::Number, t::HessianTracer) = EMPTY_HESSIAN_TRACER
@eval Base.$fn(::Number, t::T) where {T<:HessianTracer} = empty(T)
end

for fn in ops_2_to_1_zfz
@eval Base.$fn(::HessianTracer, t::HessianTracer) = t
@eval Base.$fn(::HessianTracer, ::Number) = EMPTY_HESSIAN_TRACER
@eval Base.$fn(::T, ::Number) where {T<:HessianTracer} = empty(T)
@eval Base.$fn(::Number, t::HessianTracer) = t
end

for fn in ops_2_to_1_zzz
@eval Base.$fn(::HessianTracer, t::HessianTracer) = EMPTY_HESSIAN_TRACER
@eval Base.$fn(::HessianTracer, ::Number) = EMPTY_HESSIAN_TRACER
@eval Base.$fn(::Number, t::HessianTracer) = EMPTY_HESSIAN_TRACER
@eval Base.$fn(::T, t::T) where {T<:HessianTracer} = empty(T)
@eval Base.$fn(::T, ::Number) where {T<:HessianTracer} = empty(T)
@eval Base.$fn(::Number, t::T) where {T<:HessianTracer} = empty(T)
end

# Extra types required for exponent
Expand All @@ -122,7 +122,7 @@ Base.:^(t::HessianTracer, ::Irrational{:ℯ}) = promote_order(t)
Base.:^(::Irrational{:ℯ}, t::HessianTracer) = promote_order(t)

## Rounding
Base.round(t::HessianTracer, ::RoundingMode; kwargs...) = EMPTY_HESSIAN_TRACER
Base.round(t::T, ::RoundingMode; kwargs...) where {T<:HessianTracer} = empty(T)

## Random numbers
rand(::AbstractRNG, ::SamplerType{HessianTracer}) = EMPTY_HESSIAN_TRACER
rand(::AbstractRNG, ::SamplerType{T}) where {T<:HessianTracer} = empty(T)
22 changes: 11 additions & 11 deletions src/overload_jacobian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ for fn in union(ops_1_to_1_s, ops_1_to_1_f)
end

for fn in union(ops_1_to_1_z, ops_1_to_1_const)
@eval Base.$fn(::JacobianTracer) = EMPTY_JACOBIAN_TRACER
@eval Base.$fn(::T) where {T<:JacobianTracer} = empty(T)
end

for fn in union(
Expand All @@ -23,33 +23,33 @@ end

for fn in union(ops_2_to_1_zsz, ops_2_to_1_zfz)
@eval Base.$fn(::JacobianTracer, t::JacobianTracer) = t
@eval Base.$fn(::JacobianTracer, ::Number) = EMPTY_JACOBIAN_TRACER
@eval Base.$fn(::T, ::Number) where {T<:JacobianTracer} = empty(T)
@eval Base.$fn(::Number, t::JacobianTracer) = t
end
for fn in union(ops_2_to_1_szz, ops_2_to_1_fzz)
@eval Base.$fn(t::JacobianTracer, ::JacobianTracer) = t
@eval Base.$fn(t::JacobianTracer, ::Number) = t
@eval Base.$fn(::Number, t::JacobianTracer) = EMPTY_JACOBIAN_TRACER
@eval Base.$fn(::Number, ::T) where {T<:JacobianTracer} = empty(T)
end
for fn in ops_2_to_1_zzz
@eval Base.$fn(::JacobianTracer, ::JacobianTracer) = EMPTY_JACOBIAN_TRACER
@eval Base.$fn(::JacobianTracer, ::Number) = EMPTY_JACOBIAN_TRACER
@eval Base.$fn(::Number, ::JacobianTracer) = EMPTY_JACOBIAN_TRACER
@eval Base.$fn(::T, ::T) where {T<:JacobianTracer} = empty(T)
@eval Base.$fn(::T, ::Number) where {T<:JacobianTracer} = empty(T)
@eval Base.$fn(::Number, ::T) where {T<:JacobianTracer} = empty(T)
end

for fn in union(ops_1_to_2_ss, ops_1_to_2_sf, ops_1_to_2_fs, ops_1_to_2_ff)
@eval Base.$fn(t::JacobianTracer) = (t, t)
end

for fn in union(ops_1_to_2_sz, ops_1_to_2_fz)
@eval Base.$fn(t::JacobianTracer) = (t, EMPTY_JACOBIAN_TRACER)
@eval Base.$fn(t::T) where {T<:JacobianTracer} = (t, empty(T))
end

for fn in union(ops_1_to_2_zs, ops_1_to_2_zf)
@eval Base.$fn(t::JacobianTracer) = (EMPTY_JACOBIAN_TRACER, t)
@eval Base.$fn(t::T) where {T<:JacobianTracer} = (empty(T), t)
end
for fn in ops_1_to_2_zz
@eval Base.$fn(::JacobianTracer) = (EMPTY_JACOBIAN_TRACER, EMPTY_JACOBIAN_TRACER)
@eval Base.$fn(::T) where {T<:JacobianTracer} = (empty(T), empty(T))
end

# Extra types required for exponent
Expand All @@ -61,7 +61,7 @@ Base.:^(t::JacobianTracer, ::Irrational{:ℯ}) = t
Base.:^(::Irrational{:ℯ}, t::JacobianTracer) = t

## Rounding
Base.round(t::JacobianTracer, ::RoundingMode; kwargs...) = EMPTY_JACOBIAN_TRACER
Base.round(t::T, ::RoundingMode; kwargs...) where {T<:JacobianTracer} = empty(T)

## Random numbers
rand(::AbstractRNG, ::SamplerType{JacobianTracer}) = EMPTY_JACOBIAN_TRACER
rand(::AbstractRNG, ::SamplerType{T}) where {T<:JacobianTracer} = empty(T)
Loading
Loading