Skip to content

Commit b9ea5d7

Browse files
authored
make @non_differentiable use identical pullbacks when possible (#679)
* make `@non_differentiable` use identical pullbacks when possible Fixes #678 * simpler * bump version
1 parent 3da9c1a commit b9ea5d7

4 files changed

+17
-5
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRulesCore"
22
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
3-
version = "1.23.0"
3+
version = "1.24.0"
44

55
[deps]
66
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"

src/ChainRulesCore.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ module ChainRulesCore
22
using Base.Broadcast: broadcasted, Broadcasted, broadcastable, materialize, materialize!
33
using Base.Meta
44
using LinearAlgebra
5-
using Compat: hasfield, hasproperty, ismutabletype
5+
using Compat: hasfield, hasproperty, ismutabletype, Returns
66

77
export frule, rrule # core function
88
# rule configurations

src/rule_definition_tools.jl

+1-3
Original file line numberDiff line numberDiff line change
@@ -436,9 +436,7 @@ function _nondiff_rrule_expr(__source__, primal_sig_parts, primal_invoke)
436436
tup_expr = tuple_expression(primal_sig_parts)
437437
primal_name = first(primal_invoke.args)
438438
pullback_expr = @strip_linenos quote
439-
function $(esc(propagator_name(primal_name, :pullback)))(@nospecialize(_))
440-
return $(tup_expr)
441-
end
439+
Returns($(tup_expr))
442440
end
443441

444442
@gensym kwargs

test/rule_definition_tools.jl

+14
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,20 @@ end
4242

4343
@testset "rule_definition_tools.jl" begin
4444
@testset "@non_differentiable" begin
45+
@testset "issue #678: identical pullback objects" begin
46+
issue_678_f(::Any) = nothing
47+
issue_678_g(::Any) = nothing
48+
issue_678_h(::Any...) = nothing
49+
@non_differentiable issue_678_f(::Any)
50+
@non_differentiable issue_678_g(::Any)
51+
@non_differentiable issue_678_h(::Any...)
52+
@test (
53+
last(rrule(issue_678_f, 0.1)) ===
54+
last(rrule(issue_678_g, 0.2)) ===
55+
last(rrule(issue_678_h, 0.3))
56+
)
57+
end
58+
4559
@testset "two input one output function" begin
4660
nondiff_2_1(x, y) = fill(7.5, 100)[x + y]
4761
@non_differentiable nondiff_2_1(::Any, ::Any)

0 commit comments

Comments
 (0)