From d86d5ea91e404a12e33f9b6a244bc61f93953f3c Mon Sep 17 00:00:00 2001 From: Frames White Date: Mon, 8 Jan 2024 15:07:56 +0800 Subject: [PATCH 1/5] Fix configured frule created by non_differentiable rule --- Project.toml | 2 +- src/rule_definition_tools.jl | 4 ++-- test/rule_definition_tools.jl | 18 ++++++++++++++++++ 3 files changed, 21 insertions(+), 3 deletions(-) diff --git a/Project.toml b/Project.toml index eadf0761c..195cccac5 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRulesCore" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "1.19.0" +version = "1.19.1" [deps] Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index 434c5b843..106ef76b3 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -403,13 +403,13 @@ function _nondiff_frule_expr(__source__, primal_sig_parts, primal_invoke) function (::Core.kwftype(typeof(ChainRulesCore.frule)))( @nospecialize($kwargs::Any), frule::typeof(ChainRulesCore.frule), - @nospecialize(::Any), + @nospecialize(::RuleConfig), $(map(esc, primal_sig_parts)...), ) return ($(esc(_with_kwargs_expr(primal_invoke, kwargs))), NoTangent()) end function ChainRulesCore.frule( - @nospecialize(::Any), $(map(esc, primal_sig_parts)...) + @nospecialize(::RuleConfig), $(map(esc, primal_sig_parts)...) ) $(__source__) # Julia functions always only have 1 output, so return a single NoTangent() diff --git a/test/rule_definition_tools.jl b/test/rule_definition_tools.jl index 5a177566d..926ed7cd2 100644 --- a/test/rule_definition_tools.jl +++ b/test/rule_definition_tools.jl @@ -215,6 +215,24 @@ end @test pullback(4.5) == (NoTangent(), NoTangent(), NoTangent()) end + @testset "interactions with configs" begin + struct AllConfig <: RuleConfig{Union{HasForwardsMode,NoReverseMode}} end + + foo_ndc1(x) = string(x) + @non_differentiable foo_ndc1(x) + @test frule(AllConfig(), foo_ndc1, 2.0) == (string(2.0), NoTangent()) + r1, pb1 = rrule(AllConfig(), foo_ndc1, 2.0) + @test r1 == string(2.0) + @test pb1(NoTangent()) == (NoTangent(), NoTangent()) + + foo_ndc2(x; y=0) = string(x+y) + @non_differentiable foo_ndc2(x) + @test frule(AllConfig(), foo_ndc2, 2.0; y=4.0) == (string(6.0), NoTangent()) + r2, pb2 = rrule(AllConfig(), foo_ndc2, 2.0; y=4.0) + @test r2 == string(6.0) + @test pb2(NoTangent()) == (NoTangent(), NoTangent()) + end + @testset "Not supported (Yet)" begin # Where clauses are not supported. @test_macro_throws( From 1a648791deff706676031ecbf906dd95a6d89911 Mon Sep 17 00:00:00 2001 From: Frames White Date: Mon, 8 Jan 2024 15:36:09 +0800 Subject: [PATCH 2/5] formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- test/rule_definition_tools.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/rule_definition_tools.jl b/test/rule_definition_tools.jl index 926ed7cd2..05c4d8389 100644 --- a/test/rule_definition_tools.jl +++ b/test/rule_definition_tools.jl @@ -225,7 +225,7 @@ end @test r1 == string(2.0) @test pb1(NoTangent()) == (NoTangent(), NoTangent()) - foo_ndc2(x; y=0) = string(x+y) + foo_ndc2(x; y=0) = string(x + y) @non_differentiable foo_ndc2(x) @test frule(AllConfig(), foo_ndc2, 2.0; y=4.0) == (string(6.0), NoTangent()) r2, pb2 = rrule(AllConfig(), foo_ndc2, 2.0; y=4.0) From 6befc08fe6298e33186c544b930e312618041498 Mon Sep 17 00:00:00 2001 From: Frames White Date: Tue, 9 Jan 2024 12:10:35 +0800 Subject: [PATCH 3/5] remove nospecialize --- src/rule_definition_tools.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index 106ef76b3..170798d14 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -403,13 +403,13 @@ function _nondiff_frule_expr(__source__, primal_sig_parts, primal_invoke) function (::Core.kwftype(typeof(ChainRulesCore.frule)))( @nospecialize($kwargs::Any), frule::typeof(ChainRulesCore.frule), - @nospecialize(::RuleConfig), + ::$RuleConfig, $(map(esc, primal_sig_parts)...), ) return ($(esc(_with_kwargs_expr(primal_invoke, kwargs))), NoTangent()) end function ChainRulesCore.frule( - @nospecialize(::RuleConfig), $(map(esc, primal_sig_parts)...) + ::$RuleConfig, $(map(esc, primal_sig_parts)...) ) $(__source__) # Julia functions always only have 1 output, so return a single NoTangent() From 67036a86ab03953f392446b888b04a2ef3b3c096 Mon Sep 17 00:00:00 2001 From: Frames White Date: Tue, 9 Jan 2024 14:43:29 +0800 Subject: [PATCH 4/5] Constrain generated signature for nondiff frule to need tuple first arg so no ambig with ruleconfig first arg --- src/rule_definition_tools.jl | 4 ++-- test/rule_definition_tools.jl | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index 170798d14..ef2db6b42 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -403,13 +403,13 @@ function _nondiff_frule_expr(__source__, primal_sig_parts, primal_invoke) function (::Core.kwftype(typeof(ChainRulesCore.frule)))( @nospecialize($kwargs::Any), frule::typeof(ChainRulesCore.frule), - ::$RuleConfig, + ::Tuple, $(map(esc, primal_sig_parts)...), ) return ($(esc(_with_kwargs_expr(primal_invoke, kwargs))), NoTangent()) end function ChainRulesCore.frule( - ::$RuleConfig, $(map(esc, primal_sig_parts)...) + ::Tuple, $(map(esc, primal_sig_parts)...) ) $(__source__) # Julia functions always only have 1 output, so return a single NoTangent() diff --git a/test/rule_definition_tools.jl b/test/rule_definition_tools.jl index 05c4d8389..43863a915 100644 --- a/test/rule_definition_tools.jl +++ b/test/rule_definition_tools.jl @@ -220,14 +220,14 @@ end foo_ndc1(x) = string(x) @non_differentiable foo_ndc1(x) - @test frule(AllConfig(), foo_ndc1, 2.0) == (string(2.0), NoTangent()) + @test frule(AllConfig(), (NoTangent(), NoTangent()), foo_ndc1, 2.0) == (string(2.0), NoTangent()) r1, pb1 = rrule(AllConfig(), foo_ndc1, 2.0) @test r1 == string(2.0) @test pb1(NoTangent()) == (NoTangent(), NoTangent()) foo_ndc2(x; y=0) = string(x + y) @non_differentiable foo_ndc2(x) - @test frule(AllConfig(), foo_ndc2, 2.0; y=4.0) == (string(6.0), NoTangent()) + @test frule(AllConfig(), (NoTangent(), NoTangent()), foo_ndc2, 2.0; y=4.0) == (string(6.0), NoTangent()) r2, pb2 = rrule(AllConfig(), foo_ndc2, 2.0; y=4.0) @test r2 == string(6.0) @test pb2(NoTangent()) == (NoTangent(), NoTangent()) From 161e5f14b9bdd3816ed9e2b9f070963315e163b5 Mon Sep 17 00:00:00 2001 From: Frames White Date: Tue, 9 Jan 2024 16:44:33 +0800 Subject: [PATCH 5/5] nospecialize on tuple --- src/rule_definition_tools.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index ef2db6b42..10ce7beec 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -403,13 +403,13 @@ function _nondiff_frule_expr(__source__, primal_sig_parts, primal_invoke) function (::Core.kwftype(typeof(ChainRulesCore.frule)))( @nospecialize($kwargs::Any), frule::typeof(ChainRulesCore.frule), - ::Tuple, + @nospecialize(::Tuple), $(map(esc, primal_sig_parts)...), ) return ($(esc(_with_kwargs_expr(primal_invoke, kwargs))), NoTangent()) end function ChainRulesCore.frule( - ::Tuple, $(map(esc, primal_sig_parts)...) + @nospecialize(::Tuple), $(map(esc, primal_sig_parts)...) ) $(__source__) # Julia functions always only have 1 output, so return a single NoTangent()