diff --git a/Project.toml b/Project.toml index eadf0761c..195cccac5 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRulesCore" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "1.19.0" +version = "1.19.1" [deps] Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index 434c5b843..10ce7beec 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -403,13 +403,13 @@ function _nondiff_frule_expr(__source__, primal_sig_parts, primal_invoke) function (::Core.kwftype(typeof(ChainRulesCore.frule)))( @nospecialize($kwargs::Any), frule::typeof(ChainRulesCore.frule), - @nospecialize(::Any), + @nospecialize(::Tuple), $(map(esc, primal_sig_parts)...), ) return ($(esc(_with_kwargs_expr(primal_invoke, kwargs))), NoTangent()) end function ChainRulesCore.frule( - @nospecialize(::Any), $(map(esc, primal_sig_parts)...) + @nospecialize(::Tuple), $(map(esc, primal_sig_parts)...) ) $(__source__) # Julia functions always only have 1 output, so return a single NoTangent() diff --git a/test/rule_definition_tools.jl b/test/rule_definition_tools.jl index 5a177566d..43863a915 100644 --- a/test/rule_definition_tools.jl +++ b/test/rule_definition_tools.jl @@ -215,6 +215,24 @@ end @test pullback(4.5) == (NoTangent(), NoTangent(), NoTangent()) end + @testset "interactions with configs" begin + struct AllConfig <: RuleConfig{Union{HasForwardsMode,NoReverseMode}} end + + foo_ndc1(x) = string(x) + @non_differentiable foo_ndc1(x) + @test frule(AllConfig(), (NoTangent(), NoTangent()), foo_ndc1, 2.0) == (string(2.0), NoTangent()) + r1, pb1 = rrule(AllConfig(), foo_ndc1, 2.0) + @test r1 == string(2.0) + @test pb1(NoTangent()) == (NoTangent(), NoTangent()) + + foo_ndc2(x; y=0) = string(x + y) + @non_differentiable foo_ndc2(x) + @test frule(AllConfig(), (NoTangent(), NoTangent()), foo_ndc2, 2.0; y=4.0) == (string(6.0), NoTangent()) + r2, pb2 = rrule(AllConfig(), foo_ndc2, 2.0; y=4.0) + @test r2 == string(6.0) + @test pb2(NoTangent()) == (NoTangent(), NoTangent()) + end + @testset "Not supported (Yet)" begin # Where clauses are not supported. @test_macro_throws(