Skip to content

Commit 2f2c941

Browse files
authored
Add type promotion rules for NoTangent and ZeroTangent, and add eltype for NoTangent (#682)
* Add promotion rules for ZeroTangent and NoTangent * Make NoTangent have an eltype of itself. * bump version
1 parent 9627bd6 commit 2f2c941

File tree

3 files changed

+25
-1
lines changed

3 files changed

+25
-1
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRulesCore"
22
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
3-
version = "1.24.0"
3+
version = "1.25.0"
44

55
[deps]
66
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"

src/tangent_types/abstract_zero.jl

+3
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ Base.:/(z::AbstractZero, ::Any) = z
3232

3333
Base.convert(::Type{T}, x::AbstractZero) where {T<:Number} = zero(T)
3434
# (::Type{T})(::AbstractZero, ::AbstractZero...) where {T<:Number} = zero(T)
35+
Base.promote_rule(T::Type{<:Number}, S::Type{<:AbstractZero}) = T
3536

3637
(::Type{Complex})(x::AbstractZero, y::Real) = Complex(false, y)
3738
(::Type{Complex})(x::Real, y::AbstractZero) = Complex(x, false)
@@ -92,6 +93,8 @@ end
9293
"""
9394
struct NoTangent <: AbstractZero end
9495

96+
Base.eltype(::Type{NoTangent}) = NoTangent
97+
9598
"""
9699
zero_tangent(primal)
97100

test/tangent_types/abstract_zero.jl

+21
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,15 @@
8282
@test convert(Float32, ZeroTangent()) === 0.0f0
8383
@test convert(ComplexF64, ZeroTangent()) === 0.0 + 0.0im
8484

85+
@test promote_type(ZeroTangent, Bool) == Bool
86+
@test promote_type(Bool, ZeroTangent) == Bool
87+
@test promote_type(ZeroTangent, Int64) == Int64
88+
@test promote_type(Int64, ZeroTangent) == Int64
89+
@test promote_type(ZeroTangent, Float32) == Float32
90+
@test promote_type(Float32, ZeroTangent) == Float32
91+
@test promote_type(ZeroTangent, ComplexF64) == ComplexF64
92+
@test promote_type(ComplexF64, ZeroTangent) == ComplexF64
93+
8594
@test z[1] === z
8695
@test z[1:3] === z
8796
@test z[1, 2] === z
@@ -110,6 +119,18 @@
110119
@test dot(dne, 17.2) == dne
111120
@test dot(11.9, dne) == dne
112121

122+
@test eltype(dne) === NoTangent
123+
@test eltype(NoTangent) === NoTangent
124+
125+
@test promote_type(NoTangent, Bool) == Bool
126+
@test promote_type(Bool, NoTangent) == Bool
127+
@test promote_type(NoTangent, Int64) == Int64
128+
@test promote_type(Int64, NoTangent) == Int64
129+
@test promote_type(NoTangent, Float32) == Float32
130+
@test promote_type(Float32, NoTangent) == Float32
131+
@test promote_type(NoTangent, ComplexF64) == ComplexF64
132+
@test promote_type(ComplexF64, NoTangent) == ComplexF64
133+
113134
@test ZeroTangent() + dne == dne
114135
@test dne + ZeroTangent() == dne
115136
@test ZeroTangent() - dne == dne

0 commit comments

Comments
 (0)