From 8251cb3f9558671fddea913b415f6becbf2bd9e0 Mon Sep 17 00:00:00 2001 From: Adrian Hill Date: Thu, 15 Aug 2024 13:56:41 +0200 Subject: [PATCH] Add direct type conversions on `Dual` (#168) --- src/overloads/conversion.jl | 8 ++++++++ src/patterns.jl | 4 ++-- test/test_constructors.jl | 30 ++++++++++++++++++++++++++++-- test/test_hessian.jl | 7 ++++++- 4 files changed, 44 insertions(+), 5 deletions(-) diff --git a/src/overloads/conversion.jl b/src/overloads/conversion.jl index 91f4dd3..ceff378 100644 --- a/src/overloads/conversion.jl +++ b/src/overloads/conversion.jl @@ -50,6 +50,14 @@ function Base.convert(::Type{Dual{P1,T}}, d::Dual{P2,T}) where {P1,P2,T} return Dual(convert(P1, primal(d)), tracer(d)) end +# Explicit type conversions +for T in (:Int, :Integer, :Float64, :Float32) + @eval function Base.$T(d::Dual) + isemptytracer(d) || throw(InexactError(Symbol($T), $T, d)) + $T(primal(d)) + end +end + ## Constants # These are methods defined on types. Methods on variables are in operators.jl Base.zero(::Type{D}) where {P,T,D<:Dual{P,T}} = D(zero(P), myempty(T)) diff --git a/src/patterns.jl b/src/patterns.jl index c79da1a..a07756a 100644 --- a/src/patterns.jl +++ b/src/patterns.jl @@ -143,9 +143,9 @@ function union_product!( return hessian end -#=======================# +#=========================# # AbstractGradientPattern # -#=======================# +#=========================# # For use with GradientTracer. diff --git a/test/test_constructors.jl b/test/test_constructors.jl index cc8b906..fca3eff 100644 --- a/test/test_constructors.jl +++ b/test/test_constructors.jl @@ -1,7 +1,7 @@ # Test construction and conversions of internal tracer types using SparseConnectivityTracer: AbstractTracer, GradientTracer, HessianTracer, Dual -using SparseConnectivityTracer: primal, tracer, isemptytracer -using SparseConnectivityTracer: myempty, name +using SparseConnectivityTracer: primal, tracer, isemptytracer, myempty, name +using SparseConnectivityTracer: IndexSetGradientPattern using Test # Load definitions of GRADIENT_TRACERS, GRADIENT_PATTERNS, HESSIAN_TRACERS and HESSIAN_PATTERNS @@ -200,3 +200,29 @@ end test_similar(T) end end + +@testset "Explicit type conversions on Dual" begin + @testset "$T" for T in union(GRADIENT_TRACERS, HESSIAN_TRACERS) + P = IndexSetGradientPattern{Int,BitSet} + T = GradientTracer{P} + + p = P(BitSet(2)) + t_full = T(p) + t_empty = myempty(T) + d_full = Dual(1.0, t_full) + d_empty = Dual(1.0, t_empty) + + @testset "Non-empty tracer" begin + @testset "$TOUT" for TOUT in (Int, Integer, Float64, Float32) + @test_throws InexactError TOUT(d_full) + end + end + @testset "Empty tracer" begin + @testset "$TOUT" for TOUT in (Int, Integer, Float64, Float32) + out = TOUT(d_empty) + @test out isa TOUT # not a Dual! + @test isone(out) + end + end + end +end diff --git a/test/test_hessian.jl b/test/test_hessian.jl index 2466571..9fb0d5d 100644 --- a/test/test_hessian.jl +++ b/test/test_hessian.jl @@ -9,12 +9,17 @@ using NNlib: NNlib # Load definitions of GRADIENT_TRACERS, GRADIENT_PATTERNS, HESSIAN_TRACERS and HESSIAN_PATTERNS include("tracers_definitions.jl") +REAL_TYPES = (Float64, Int, Bool, UInt8, Float16, Rational{Int}) -# This exists to be able to quickly run tests in the REPL. +# These exists to be able to quickly run tests in the REPL. # NOTE: H gets overwritten inside the testsets. method = TracerSparsityDetector() H(f, x) = hessian_sparsity(f, x, method) +P = first(HESSIAN_PATTERNS) +T = HessianTracer{P} +D = Dual{Int,T} + @testset "Global Hessian" begin @testset "$P" for P in HESSIAN_PATTERNS T = HessianTracer{P}