Explicitly ignore derivatives of argument checks #1492
Codecov Report
@@ Coverage Diff @@
## master #1492 +/- ##
==========================================
- Coverage 84.44% 84.24% -0.21%
==========================================
Files 124 124
Lines 7522 7513 -9
==========================================
- Hits 6352 6329 -23
- Misses 1170 1184 +14
Continue to review full report at Codecov.
|
I'm not a big fan of how much ChainRules-specific code is now being included in these functions. What if the |
This was my initial approach but unfortunately it is not sufficient if one does not ignore the whole blocks of checks, including |
I considered this as well, but it felt like once it becomes completely general and flexible, it would at some point be easier not to use a macro and just implement the checks explicitly. It's also not only used in constructors, and IIRC there's at least one case where some variables are unpacked/defined in the

So basically my impression was that the macro should be able to generate something like

    ChainRulesCore.ignore_derivatives() do
        if some_var # or only if check_args
            some_expr
            cond1 || throw(ArgumentError(cond1_msg))
            cond2 || ...
        end
    end

and it was not immediately clear whether a short but descriptive syntax exists, and what it could look like, that would provide a major advantage over implementing the block directly. Do you have an idea how it could look? |
This looks like fixing a problem in the wrong place. |
It is exactly designed for such a use case: https://juliadiff.org/ChainRulesCore.jl/dev/rule_author/tips_for_packages.html#Ignoring-gradients-for-certain-expressions And I don't see why/how it should interact with other non-AD packages, it just executes the function in the primal: https://github.com/JuliaDiff/ChainRulesCore.jl/blob/699e61f1539fdba362fff5a1b438fbccf32370f0/src/ignore_derivatives.jl#L26 |
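For readers unfamiliar with the pattern, here is a minimal sketch of wrapping an argument check in `ChainRulesCore.ignore_derivatives`. The function `make_normal_params` and its error message are hypothetical and not part of Distributions.jl; the only library call assumed is `ignore_derivatives` with `do`-block syntax, as described in the linked documentation.

```julia
using ChainRulesCore

# Hypothetical helper (illustration only): validates σ and returns (μ, σ).
# The check still runs in the primal computation, but ChainRules-based AD
# backends are told not to differentiate through it.
function make_normal_params(μ::Real, σ::Real; check_args::Bool=true)
    ChainRulesCore.ignore_derivatives() do
        if check_args
            σ > zero(σ) || throw(ArgumentError("σ must be positive, got $σ"))
        end
    end
    return (μ, σ)
end
```

Without AD the call behaves like the plain check: `make_normal_params(0.0, -1.0)` throws an `ArgumentError`, while passing `check_args=false` skips the check.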
The benchmarks also show that there is no overhead for non-AD use. And the AD performance issues on master are quite problematic for downstream applications, which cannot address them properly themselves (i.e. not without type piracy, and in some cases not even that is possible since the check is performed in the inner constructor). |
@devmotion I can't think of a nice syntax for the macro, and your points are all valid. I guess another approach would be to define a |
I thought about moving the checks to a separate function but this seemed unnecessarily complicated - for

To summarize, I think it's correct to address these performance issues in Distributions but it would be good to do it in a more developer-friendly and "less special" way. I don't think it's an AD backend issue since I don't think it can be expected in general that AD neglects these checks - without us telling it that it should.

I thought about a more convenient macro. Maybe we could use something like |
I modified the

    @check_args(
        D,
        @setup(statements...),
        (cond₁, message₁),
        (cond₂, message₂),
        ...,
    )

A convenience macro that generates AD-compatible checks of arguments for a distribution of
type `D`.
More concretely, it generates the following Julia code:
```julia
ChainRulesCore.ignore_derivatives() do
    if check_args
        \$(statements...)
        cond₁ || throw(ArgumentError(\$(string(D, ": ", message₁))))
        cond₂ || throw(ArgumentError(\$(string(D, ": ", message₂))))
        ...
    end
end
```
The `@setup` argument can be elided if no setup code is needed. Moreover, error messages
can be omitted. In this case the message `"the condition \$(cond) is not satisfied."` is
used. I wonder though if it is too surprising that one has to define a boolean variable |
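To make the expansion described in the docstring concrete, here is a simplified sketch of such a macro. It is not the macro from this PR: `@check_args_sketch` and `scale_param` are hypothetical names, `@setup` and default messages are omitted, and `ignore_ad` is a dependency-free stand-in (an assumption) for `ChainRulesCore.ignore_derivatives` that just calls its argument.

```julia
# Stand-in for ChainRulesCore.ignore_derivatives (assumption): calls `f` as is.
ignore_ad(f) = f()

# Hypothetical, simplified macro: `D` is only used as a label in the message,
# and every check must be a `(condition, message)` tuple.
macro check_args_sketch(D, checks...)
    stmts = map(checks) do ex
        (ex isa Expr && ex.head === :tuple && length(ex.args) == 2) ||
            error("expected (condition, message) tuples")
        cond, msg = ex.args
        :($(esc(cond)) || throw(ArgumentError(string($(string(D)), ": ", $(esc(msg))))))
    end
    return quote
        ignore_ad() do
            # Relies on a Boolean `check_args` being defined at the call site.
            if $(esc(:check_args))
                $(stmts...)
            end
        end
    end
end

# Usage: the caller has to provide the `check_args` variable itself.
function scale_param(σ::Real; check_args::Bool=true)
    @check_args_sketch(MyScale, (σ > zero(σ), "σ must be positive."))
    return σ
end
```

The sketch also shows the point raised above: the macro silently depends on a variable named `check_args` existing in the calling scope.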
Why not just have a dedicated function

    check_arguments(checkargs, f::Function) = checkargs && f()
    ChainRulesCore.@non_differentiable check_arguments(checkargs, f)

This says positively what happens ("checking arguments") instead of negatively ("can't differentiate the following code"). |
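A runnable sketch of this suggestion, under a few assumptions: the function below returns `nothing` instead of the result of `checkargs && f()`, accepts any callable rather than only `Function`, and `double_positive` is a hypothetical caller added for illustration.

```julia
using ChainRulesCore

# Run the checks in `f` only when `checkargs` is true; always return `nothing`.
function check_arguments(checkargs::Bool, f)
    checkargs && f()
    return nothing
end

# Mark the whole call as non-differentiable for ChainRules-based AD backends.
ChainRulesCore.@non_differentiable check_arguments(checkargs, f)

# Hypothetical caller: the check is passed as a closure.
function double_positive(x::Real; check_args::Bool=true)
    check_arguments(check_args, () -> x > zero(x) || throw(DomainError(x, "x must be positive")))
    return 2x
end
```

With this argument order the closure has to be written out explicitly; putting `f` first would allow `do`-block syntax at call sites, which may be the more convenient design.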
Sure, this could be done but it seems equivalent to what |
We would own |
We can thus depend on a hypothetical
if that is required. |
I don't see a problem with calling ChainRules - it's documented and guaranteed to not affect the primal computation. Its sole purpose is to not have to define separate functions and mark them as non-differentiable if you want to mark parts of a function body as non-differentiable. Can you explain your example? What is |
It is a hypothetical example. There is not only automatic differentiation, but also other things we want to do in an automatic fashion in Julia, for example automatic uncertainty propagation etc. |
Ah, OK, so in your example I don't have a strong preference, it's simple to switch from one approach to the other since they are only called in the macro. If the consensus is to use a custom function, I can just copy the implementation of |
I don't have a strong opinion wrt function vs macro, but from a design perspective I do prefer both options over calling |
I tried to incorporate all comments, the PR should be ready for a proper review. |
@@ -413,7 +413,7 @@ end
 function test_special(dist::Type{LKJ})
     @testset "LKJ mode" begin
         @test mode(LKJ(5, 1.5)) == mean(LKJ(5, 1.5))
-        @test_throws ArgumentError mode( LKJ(5, 0.5) )
+        @test_throws DomainError mode( LKJ(5, 0.5) )
Is this change considered breaking?
No, not according to ColPrac: https://colprac.sciml.ai/#changes-that-are-not-considered-breaking
LGTM!
@mschauer are you happy with the approach in this PR? I added a custom non-differentiable function and generalized the macro. |
@willtebbutt noticed in JuliaGaussianProcesses/AbstractGPs.jl#256 (comment) that the seemingly simple argument checks cause a massive slowdown in AD, specifically with Zygote. Hence this PR ignores derivatives of such checks explicitly.
@willtebbutt's example with Normal shows that with this PR AD is very performant: there is basically no overhead compared with the primal and zero allocations, whereas on master it is roughly 60 times slower.
On master:
This PR: