diff --git a/Project.toml b/Project.toml index 9d89ca821..91a7b526e 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRulesCore" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "1.15.5" +version = "1.15.6" [deps] Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" diff --git a/src/tangent_types/notimplemented.jl b/src/tangent_types/notimplemented.jl index 661308a11..7016acd60 100644 --- a/src/tangent_types/notimplemented.jl +++ b/src/tangent_types/notimplemented.jl @@ -43,6 +43,11 @@ Base.:/(x::NotImplemented, ::Any) = throw(NotImplementedException(x)) Base.:/(::Any, x::NotImplemented) = throw(NotImplementedException(x)) Base.:/(x::NotImplemented, ::NotImplemented) = throw(NotImplementedException(x)) +# Fix method ambiguity errors (#589) +Base.:/(x::AbstractZero, ::NotImplemented) = x +Base.:/(x::NotImplemented, ::AbstractThunk) = throw(NotImplementedException(x)) +Base.:/(::AbstractThunk, x::NotImplemented) = throw(NotImplementedException(x)) + Base.zero(x::NotImplemented) = throw(NotImplementedException(x)) function Base.zero(::Type{<:NotImplemented}) return throw( diff --git a/src/tangent_types/thunks.jl b/src/tangent_types/thunks.jl index 8baa006e8..3af284671 100644 --- a/src/tangent_types/thunks.jl +++ b/src/tangent_types/thunks.jl @@ -33,8 +33,13 @@ Base.:(==)(a::AbstractThunk, b::AbstractThunk) = unthunk(a) == unthunk(b) Base.:(-)(a::AbstractThunk) = -unthunk(a) Base.:(-)(a::AbstractThunk, b) = unthunk(a) - b Base.:(-)(a, b::AbstractThunk) = a - unthunk(b) +Base.:(-)(a::AbstractThunk, b::AbstractThunk) = unthunk(a) - unthunk(b) Base.:(/)(a::AbstractThunk, b) = unthunk(a) / b Base.:(/)(a, b::AbstractThunk) = a / unthunk(b) +Base.:(/)(a::AbstractThunk, b::AbstractThunk) = unthunk(a) / unthunk(b) + +# Fix method ambiguity issue +Base.:/(a::AbstractZero, ::AbstractThunk) = a Base.real(a::AbstractThunk) = real(unthunk(a)) Base.imag(a::AbstractThunk) = imag(unthunk(a)) diff --git a/test/tangent_types/notimplemented.jl b/test/tangent_types/notimplemented.jl index ea460dfd5..e113475c1 100644 --- a/test/tangent_types/notimplemented.jl +++ b/test/tangent_types/notimplemented.jl @@ -21,7 +21,7 @@ @test ni + ni2 === ni @test ni2 + ni === ni2 - # multiplication and dot product + # multiplication, division, and dot product @test -ni == ni for a in (true, x, thunk) @test ni * a === ni @@ -32,6 +32,7 @@ for a in (NoTangent(), ZeroTangent()) @test ni * a === a @test a * ni === a + @test a / ni === a @test dot(ni, a) === a @test dot(a, ni) === a end @@ -52,8 +53,10 @@ @test_throws E a - ni end @test_throws E ni - ni2 - @test_throws E ni / x - @test_throws E x / ni + for a in (true, x, thunk) + @test_throws E ni / a + @test_throws E a / ni + end @test_throws E ni / ni2 @test_throws E zero(ni) @test_throws E zero(typeof(ni)) diff --git a/test/tangent_types/thunks.jl b/test/tangent_types/thunks.jl index 68c2cc53a..c799e5c09 100644 --- a/test/tangent_types/thunks.jl +++ b/test/tangent_types/thunks.jl @@ -101,8 +101,15 @@ @test 1 == -@thunk(-1) @test 1 == @thunk(2) - 1 @test 1 == 2 - @thunk(1) + @test 1 == @thunk(2) - @thunk(1) @test 1.0 == @thunk(1) / 1.0 @test 1.0 == 1.0 / @thunk(1) + @test 1 == @thunk(1) / @thunk(1) + + # check method ambiguities (#589) + for a in (ZeroTangent(), NoTangent()) + @test a / @thunk(2) === a + end @test 1 == real(@thunk(1 + 1im)) @test 1 == imag(@thunk(1 + 1im))