Skip to content

Commit ae7b114

Browse files
committed
make @non_differentiable use identical pullbacks when possible
Fixes #678
1 parent fa530b9 commit ae7b114

File tree

2 files changed

+53
-7
lines changed

2 files changed

+53
-7
lines changed

src/rule_definition_tools.jl

+12-7
Original file line numberDiff line numberDiff line change
@@ -418,27 +418,32 @@ function _nondiff_frule_expr(__source__, primal_sig_parts, primal_invoke)
418418
end
419419
end
420420

421-
function tuple_expression(primal_sig_parts)
421+
function _make_pullback_for_non_differentiable(::Val{N}) where {N}
422+
Vararg{Any,N} # throw early for invalid `N`, must be nonnegative `Int`
423+
function pullback_for_non_differentiable(::Any)
424+
ntuple(Returns(NoTangent()), Val(N))
425+
end
426+
end
427+
428+
function tuple_length_expression(primal_sig_parts)
422429
has_vararg = _isvararg(primal_sig_parts[end])
423430
return if !has_vararg
424431
num_primal_inputs = length(primal_sig_parts)
425-
Expr(:tuple, ntuple(_ -> NoTangent(), num_primal_inputs)...)
432+
:($num_primal_inputs)
426433
else
427434
num_primal_inputs = length(primal_sig_parts) - 1 # - vararg
428435
length_expr =
429436
:($num_primal_inputs + length($(esc(_unconstrain(primal_sig_parts[end])))))
430-
@strip_linenos :(ntuple(i -> NoTangent(), $length_expr))
437+
@strip_linenos :($length_expr)
431438
end
432439
end
433440

434441
function _nondiff_rrule_expr(__source__, primal_sig_parts, primal_invoke)
435442
esc_primal_sig_parts = map(esc, primal_sig_parts)
436-
tup_expr = tuple_expression(primal_sig_parts)
443+
tup_len_expr = tuple_length_expression(primal_sig_parts)
437444
primal_name = first(primal_invoke.args)
438445
pullback_expr = @strip_linenos quote
439-
function $(esc(propagator_name(primal_name, :pullback)))(@nospecialize(_))
440-
return $(tup_expr)
441-
end
446+
_make_pullback_for_non_differentiable(Val{$(tup_len_expr)}())
442447
end
443448

444449
@gensym kwargs

test/rule_definition_tools.jl

+41
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,47 @@ end
4242

4343
@testset "rule_definition_tools.jl" begin
4444
@testset "@non_differentiable" begin
45+
@testset "`_make_pullback_for_non_differentiable`" begin
46+
f = ChainRulesCore._make_pullback_for_non_differentiable
47+
@testset "throws on invalid input" begin
48+
@test_throws Exception f(Val(0.0))
49+
@test_throws Exception f(Val(-1))
50+
end
51+
@testset "identical objects" begin
52+
for i 0:5
53+
v = Val(i)
54+
@test f(v) === f(v)
55+
end
56+
end
57+
@testset "correctness" begin
58+
for i 0:5
59+
expected = ntuple((_ -> NoTangent()), i)
60+
@test f(Val(i))(:arbitrary) === expected
61+
end
62+
end
63+
@testset "dispatch" begin
64+
for i 0:5
65+
pullback = f(Val(i))
66+
@test_throws MethodError pullback()
67+
@test_throws MethodError pullback(1, 2)
68+
end
69+
end
70+
end
71+
72+
@testset "issue #678: identical pullback objects" begin
73+
issue_678_f(::Any) = nothing
74+
issue_678_g(::Any) = nothing
75+
issue_678_h(::Any...) = nothing
76+
@non_differentiable issue_678_f(::Any)
77+
@non_differentiable issue_678_g(::Any)
78+
@non_differentiable issue_678_h(::Any...)
79+
@test (
80+
last(rrule(issue_678_f, 0.1)) ===
81+
last(rrule(issue_678_g, 0.2)) ===
82+
last(rrule(issue_678_h, 0.3))
83+
)
84+
end
85+
4586
@testset "two input one output function" begin
4687
nondiff_2_1(x, y) = fill(7.5, 100)[x + y]
4788
@non_differentiable nondiff_2_1(::Any, ::Any)

0 commit comments

Comments
 (0)