Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Explicitly ignore derivatives of argument checks #1492

Merged
merged 15 commits into from
Jan 31, 2022

Conversation

devmotion
Copy link
Member

@willtebbutt noticed in JuliaGaussianProcesses/AbstractGPs.jl#256 (comment) that the seemingly simple argument checks cause a massive slowdown in AD, specifically with Zygote. Hence this PR ignores derivatives of such checks explicitly.

@willtebbutt's example with Normal shows that with this PR AD is very performant, there is basically no overhead compared with the primal and zero allocations, whereas on master it is roughly 60 times slower:

On master:

julia> using BenchmarkTools, Distributions, Zygote

julia> @benchmark Normal(randn(), rand() + 1)
BenchmarkTools.Trial: 10000 samples with 999 evaluations.
 Range (min  max):  7.032 ns  42.774 ns  ┊ GC (min  max): 0.00%  0.00%
 Time  (median):     7.482 ns              ┊ GC (median):    0.00%
 Time  (mean ± σ):   7.749 ns ±  0.827 ns  ┊ GC (mean ± σ):  0.00% ± 0.00%

        ▄▅▇▇█▆▃▁
  ▁▂▂▄▆█████████▇▅▄▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▃▂▂▂▂▂▂▁▁ ▃
  7.03 ns        Histogram: frequency by time        9.44 ns <

 Memory estimate: 0 bytes, allocs estimate: 0.

julia> @benchmark Zygote._pullback($(Zygote.Context()), Normal, randn(), rand() + 1)
BenchmarkTools.Trial: 10000 samples with 198 evaluations.
 Range (min  max):  435.172 ns  85.840 μs  ┊ GC (min  max):  0.00%  98.94%
 Time  (median):     454.141 ns              ┊ GC (median):     0.00%
 Time  (mean ± σ):   628.774 ns ±  2.690 μs  ┊ GC (mean ± σ):  25.31% ±  5.91%

  ▃██▇▇▅▄▄▃▃▂▂▂▂▁▂▁▁▂▂▁                                        ▂
  ████████████████████████▇▆▅▆▅▆▆▆▆▆▆▇▄▅▆▆▄▄▁▅▄▆▄▃▅▄▅▆▅▄▃▄▄▄▄▄ █
  435 ns        Histogram: log(frequency) by time       797 ns <

 Memory estimate: 848 bytes, allocs estimate: 17.

This PR:

julia> using BenchmarkTools, Distributions, Zygote

julia> @benchmark Normal(randn(), rand() + 1)
BenchmarkTools.Trial: 10000 samples with 999 evaluations.
 Range (min  max):  6.989 ns  40.634 ns  ┊ GC (min  max): 0.00%  0.00%
 Time  (median):     7.468 ns              ┊ GC (median):    0.00%
 Time  (mean ± σ):   7.734 ns ±  0.801 ns  ┊ GC (mean ± σ):  0.00% ± 0.00%

         ▃▆▇█▆▄▂
  ▁▁▁▂▄▆█████████▆▄▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▃▃▂▂▂▂▁▁▁ ▃
  6.99 ns        Histogram: frequency by time         9.4 ns <

 Memory estimate: 0 bytes, allocs estimate: 0.

julia> @benchmark Zygote._pullback($(Zygote.Context()), Normal, randn(), rand() + 1)
BenchmarkTools.Trial: 10000 samples with 999 evaluations.
 Range (min  max):  7.010 ns  40.456 ns  ┊ GC (min  max): 0.00%  0.00%
 Time  (median):     7.466 ns              ┊ GC (median):    0.00%
 Time  (mean ± σ):   7.718 ns ±  0.789 ns  ┊ GC (mean ± σ):  0.00% ± 0.00%

        ▃▆███▆▃▁
  ▁▁▂▃▅█████████▆▅▃▃▂▂▂▁▂▁▁▁▂▁▁▁▁▁▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▁ ▃
  7.01 ns        Histogram: frequency by time        9.45 ns <

 Memory estimate: 0 bytes, allocs estimate: 0.

@codecov-commenter
Copy link

codecov-commenter commented Jan 23, 2022

Codecov Report

Merging #1492 (e15030a) into master (02bcbf8) will decrease coverage by 0.20%.
The diff coverage is 76.72%.

Impacted file tree graph

@@            Coverage Diff             @@
##           master    #1492      +/-   ##
==========================================
- Coverage   84.44%   84.24%   -0.21%     
==========================================
  Files         124      124              
  Lines        7522     7513       -9     
==========================================
- Hits         6352     6329      -23     
- Misses       1170     1184      +14     
Impacted Files Coverage Δ
src/utils.jl 64.40% <12.50%> (-35.60%) ⬇️
src/univariate/continuous/beta.jl 70.96% <25.00%> (+0.56%) ⬆️
src/univariate/continuous/erlang.jl 67.64% <33.33%> (+1.93%) ⬆️
src/multivariate/multinomial.jl 85.03% <50.00%> (+1.37%) ⬆️
src/cholesky/lkjcholesky.jl 100.00% <100.00%> (ø)
src/edgeworth.jl 96.49% <100.00%> (+0.06%) ⬆️
src/matrix/lkj.jl 99.17% <100.00%> (-0.03%) ⬇️
src/multivariate/dirichlet.jl 72.81% <100.00%> (+0.48%) ⬆️
src/univariate/continuous/arcsine.jl 88.88% <100.00%> (ø)
src/univariate/continuous/betaprime.jl 93.18% <100.00%> (ø)
... and 55 more

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 16d4091...e15030a. Read the comment docs.

@sethaxen
Copy link
Contributor

I'm not a big fan of how much ChainRules-specific code is now being included in these functions. What if the @check_args macro included ChainRulesCore.ignore_derivatives internally? Then you would only need to call it for cases where a constructor doesn't use @check_args. Also, what if @check_args was generalized to support multiple conditions and custom error messages, which would I think allow it to be used everywhere @check_args currently isn't used?

@devmotion
Copy link
Member Author

What if the @check_args macro included ChainRulesCore.ignore_derivatives internally? Then you would only need to call it for cases where a constructor doesn't use @check_args.

This was my initial approach but unfortunately it is not sufficient if one does not ignore the whole blocks of checks, including if check_args.

@devmotion
Copy link
Member Author

devmotion commented Jan 23, 2022

Also, what if @check_args was generalized to support multiple conditions and custom error messages, which would I think allow it to be used everywhere @check_args currently isn't used?

I considered this as well but it felt like if it becomes completely general and flexible it would become easier at some point to not use a macro but just implement the checks explicitly. It's also not only used in constructors and IIRC there's at least one case where some variables are unpacked/defined in the check_args block. And it felt a bit weird to include if check_args in the macro which would be necessary as well.

So basically my impressoon was that the macro should be able to generate something like

ChainRulesCore.ignore_derivatives() do
    if some_var # or only if check_args
        some_expr
        cond1 || throw(ArgumentError(cond1_msg))
        cond2 || ...
    end
end

and it was not immediately clear if and what a short but descriptive syntax could look that would provide a major advantage over implementing the block directly.

Do you have an idea how it could look?

@mschauer
Copy link
Member

I'm not a big fan of how much ChainRules-specific code is now being included in these functions.

This looks like fixing a problem in the wrong place.
This would now work for AD, but every other nonstandard evaluation besides AD (particles, uncertainty intervals, things we don’t anticipate) has to deal with ChainRules specific code.

@devmotion
Copy link
Member Author

It is exactly designed for such a use case: https://juliadiff.org/ChainRulesCore.jl/dev/rule_author/tips_for_packages.html#Ignoring-gradients-for-certain-expressions And I don't see why/how it should interact with other non-AD packages, it just executes the function in the primal: https://github.com/JuliaDiff/ChainRulesCore.jl/blob/699e61f1539fdba362fff5a1b438fbccf32370f0/src/ignore_derivatives.jl#L26

@devmotion
Copy link
Member Author

devmotion commented Jan 23, 2022

The benchmarks also show that there is no overhead for non-AD use. And the AD performance issues on master are quite problematic for downstream applications but can't be addressed properly by them (ie. without type piracy - and in some cases not even this is possible since the check is performed in the inner constructor).

@sethaxen
Copy link
Contributor

@devmotion I can't think of a nice syntax for the macro, and your points are all valid.

I guess another approach would be to define a _maybe_check_args(::Type{<:Distribution}, check::Bool, ps...) function where all checks are performed for each distribution (each with their own overload). Then mark it globally as ChainRulesCore.@non_differentiable. But I suspect that would only work for ChainRules, since for operator-overloading ADs, one would need to dispatch on the param values, so there would be unresolved ambiguity.

@devmotion
Copy link
Member Author

I thought about moving the checks to a separate function but this seemed unnecessarily complicated - for ignore_derivatives one already defines separate distribution specific functions anyway, without adding more methods to the module and hence without potential ambiguity issues.

To summarize, I think it's correct to address these performance issues in Distributions but it would be good to do it in a more developer friendly and "less special" way. I don't think it's an AD backend issue since I don't think it can be expected in general that AD neglects these checks - without us telling it that it should.

I thought about a more convenient macro. Maybe we could use something like @check_args(Beta, a > zero(a), (b > zero(b)) => "b should be positive!") . This would not include any additional expressions but maybe that's fine for now (even though it's a bit annoying that AD won't ignore them if they are outside of the function...).

@devmotion
Copy link
Member Author

I modified the @check_args macro, inspired by ChainRulesCore.@scalar_rule:

    @check_args(
        D,
        @setup(statements...),
        (cond₁, message₁),
        (cond₂, message₂),
        ...,
    )

A convenience macro that generates AD-compatible checks of arguments for a distribution of
type `D`.

More concretely, it generates the following Julia code:
```julia
ChainRulesCore.ignore_derivatives() do
    if check_args
        \$(statements...)
        cond₁ || throw(ArgumentError(\$(string(D, ": ", message₁))))
        cond₂ || throw(ArgumentError(\$(string(D, ": ", message₂))))
        ...
    end
end
```

The `@setup` argument can be elided if no setup code is needed. Moreover, error messages
can be omitted. In this case the message `"the condition \$(cond) is not satisfied."` is
used.

I wonder though if it is too surprising that one has to define a boolean variable check_args. Maybe its name should be passed in the macro call, i.e., something like @check_args(check_args, D, ...)? On the other hand, this would feel a bit redundant.

@mschauer
Copy link
Member

Why not just having a dedicated function

check_arguments(checkargs, f::Function) = checkargs && f()
ChainRulesCore.@non_differentiable check_arguments(checkargs, f)

This is positively saying what happens ("checking arguments") instead of negatively "can't differentiate the following code").

@devmotion
Copy link
Member Author

Sure, this could be done but it seems equivalent to what ignore_derivatives does. In the current state of the PR ignore_derivatives only shows up in the code of the macro - and similarly the call to a check_args function would only show up there if we don't want to remove the macro and implement all checks explicitly.

@mschauer
Copy link
Member

We would own check_arguments so on the path of standard evaluation we would not call chainrules code at all and we are free to add any methods

@mschauer
Copy link
Member

We can thus depend on a hypothetical UncertaintyCore.jl and define

UncertaintyCore.@deterministic check_arguments(checkargs, f)

if that is required.

@devmotion
Copy link
Member Author

I don't see a problem with calling ChainRules - it's documented and guarenteed to not affect the primal computation. Its sole purpose is to not have to define separate functions and mark them as non-differentiable if you want to mark parts of a function body as non-differentiable.

Can you explain your example? What is @deterministic supposed to do? In any case if the function you proposed would be owned by Distributions it would be impossible to extend it in other packages - it would be type piracy and if the functions would be defined locally, eg with a do block, they would not even be available for dispatch. I expect similar problems if another package would define such a function and we would like to extend it.

@mschauer
Copy link
Member

It is a hypothetical example. There is not only automatic differentiation, but also other things we want to do in an automatic fashion in Julia, for example automatic uncertainty propagation etc.
Now we have seen that declaring the argument check ChainRulesCore.@non_differentiable is necessary for automatic differentiability performance, so we should anticipate the need to declare the argument check also to be ignored for other use cases in other automatic frameworks. For example that it adds no uncertainty in uncertainty propagation or whatever else might coming up later.

@devmotion
Copy link
Member Author

we should anticipate the need to declare the argument check also to be ignored for other use cases in other automatic frameworks. For example that it adds no uncertainty in uncertainty propagation or whatever else might coming up later.

Ah, OK, so in your example @deterministic would be similar to @non_differentiable. I thought you wanted to extend the function in some other package or make it an extension of a function in another package, which would both be impossible I think.

I don't have a strong preference, it's simple to switch from one approach to the other since they are only called in the macro. If the consensus is to use a custom function, I can just copy the implementation of ignore_derivatives and use it instead 🤷‍♂️

@sethaxen
Copy link
Contributor

I don't have a strong opinion wrt function vs macro, but from a design perspective I do prefer both options over calling ChainRules.ignore_derivative directly, and this way those contributing new distributions can just use our internal function/macro without even needing to think about ChainRules and how ADs handle argument checking.

src/utils.jl Show resolved Hide resolved
src/utils.jl Show resolved Hide resolved
@devmotion
Copy link
Member Author

I tried to incorporate all comments, the PR should be ready for a proper review.

@devmotion
Copy link
Member Author

@mschauer @sethaxen I would appreciate if you could review this PR and check if I managed to address your comments and suggestions.

src/utils.jl Show resolved Hide resolved
src/utils.jl Outdated Show resolved Hide resolved
@@ -413,7 +413,7 @@ end
function test_special(dist::Type{LKJ})
@testset "LKJ mode" begin
@test mode(LKJ(5, 1.5)) == mean(LKJ(5, 1.5))
@test_throws ArgumentError mode( LKJ(5, 0.5) )
@test_throws DomainError mode( LKJ(5, 0.5) )
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this change considered breaking?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor

@sethaxen sethaxen left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

@devmotion
Copy link
Member Author

@mschauer are you happy with the approach in this PR? I added a custom non-differentiable function and generalized the macro.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants