Skip to content

Commit

Permalink
Add direct type conversions on Dual (#168)
Browse files Browse the repository at this point in the history
  • Loading branch information
adrhill authored Aug 15, 2024
1 parent 4a7ba7e commit 8251cb3
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 5 deletions.
8 changes: 8 additions & 0 deletions src/overloads/conversion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,14 @@ function Base.convert(::Type{Dual{P1,T}}, d::Dual{P2,T}) where {P1,P2,T}
return Dual(convert(P1, primal(d)), tracer(d))
end

# Explicit type conversions
for T in (:Int, :Integer, :Float64, :Float32)
@eval function Base.$T(d::Dual)
isemptytracer(d) || throw(InexactError(Symbol($T), $T, d))
$T(primal(d))
end
end

## Constants
# These are methods defined on types. Methods on variables are in operators.jl
Base.zero(::Type{D}) where {P,T,D<:Dual{P,T}} = D(zero(P), myempty(T))
Expand Down
4 changes: 2 additions & 2 deletions src/patterns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -143,9 +143,9 @@ function union_product!(
return hessian
end

#=======================#
#=========================#
# AbstractGradientPattern #
#=======================#
#=========================#

# For use with GradientTracer.

Expand Down
30 changes: 28 additions & 2 deletions test/test_constructors.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Test construction and conversions of internal tracer types
using SparseConnectivityTracer: AbstractTracer, GradientTracer, HessianTracer, Dual
using SparseConnectivityTracer: primal, tracer, isemptytracer
using SparseConnectivityTracer: myempty, name
using SparseConnectivityTracer: primal, tracer, isemptytracer, myempty, name
using SparseConnectivityTracer: IndexSetGradientPattern
using Test

# Load definitions of GRADIENT_TRACERS, GRADIENT_PATTERNS, HESSIAN_TRACERS and HESSIAN_PATTERNS
Expand Down Expand Up @@ -200,3 +200,29 @@ end
test_similar(T)
end
end

@testset "Explicit type conversions on Dual" begin
@testset "$T" for T in union(GRADIENT_TRACERS, HESSIAN_TRACERS)
P = IndexSetGradientPattern{Int,BitSet}
T = GradientTracer{P}

p = P(BitSet(2))
t_full = T(p)
t_empty = myempty(T)
d_full = Dual(1.0, t_full)
d_empty = Dual(1.0, t_empty)

@testset "Non-empty tracer" begin
@testset "$TOUT" for TOUT in (Int, Integer, Float64, Float32)
@test_throws InexactError TOUT(d_full)
end
end
@testset "Empty tracer" begin
@testset "$TOUT" for TOUT in (Int, Integer, Float64, Float32)
out = TOUT(d_empty)
@test out isa TOUT # not a Dual!
@test isone(out)
end
end
end
end
7 changes: 6 additions & 1 deletion test/test_hessian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,17 @@ using NNlib: NNlib

# Load definitions of GRADIENT_TRACERS, GRADIENT_PATTERNS, HESSIAN_TRACERS and HESSIAN_PATTERNS
include("tracers_definitions.jl")
REAL_TYPES = (Float64, Int, Bool, UInt8, Float16, Rational{Int})

# This exists to be able to quickly run tests in the REPL.
# These exists to be able to quickly run tests in the REPL.
# NOTE: H gets overwritten inside the testsets.
method = TracerSparsityDetector()
H(f, x) = hessian_sparsity(f, x, method)

P = first(HESSIAN_PATTERNS)
T = HessianTracer{P}
D = Dual{Int,T}

@testset "Global Hessian" begin
@testset "$P" for P in HESSIAN_PATTERNS
T = HessianTracer{P}
Expand Down

0 comments on commit 8251cb3

Please sign in to comment.