How to use ChainRules rrules with autodiff? #583

Closed
maxfreu opened this issue Jan 31, 2023 · 8 comments

@maxfreu

maxfreu commented Jan 31, 2023

I couldn't find a better title, so let me explain: I have a function interpolated via BSplineKit, based on a lookup table. BSplineKit provides rrules for differentiating the interpolated function. Now I want to use this function in code that I want to differentiate with Enzyme, because Zygote is too slow. Is that already possible? My first lazy attempts resulted in segfaults :(

@vchuravy
Member

No, that is currently not possible. #172 is the issue to follow, but we are not targeting ChainRules support from the get-go; rather, we are allowing the user to provide rules in an Enzyme-compatible fashion.

@vchuravy closed this as not planned on Jan 31, 2023
@maxfreu
Author

maxfreu commented Feb 1, 2023

Thanks for the hint & the development work!

@CarloLucibello
Collaborator

I would suggest reopening this, unless there is some big technical blocker. A lot of work has gone into defining a big set of rules in https://github.com/JuliaDiff/ChainRules.jl, and it may take some time for EnzymeRules to catch up.

Moreover, if we want to encourage a smooth transition to Enzyme for Flux users and the ML ecosystem in general (see #805), it would be nice to support the custom rules that people have been writing with ChainRules for years.

@wsmoses
Member

wsmoses commented Feb 13, 2024

I'm not opposed to this, but there are several challenges/limitations.

  1. Most ChainRules implicitly assume that mutation does not occur. For example, consider the A * B rule. The pullback would store A and B (by reference). However, if A is overwritten between the forward and reverse pass, the data in A will have changed and the ChainRule will silently get the answer wrong as a result.
  2. Most functions that have rules in ChainRules should not have rules in Enzyme, because Enzyme doesn't usually need rules for all Julia functions (e.g. working from the lower level up, rules for most code can be generated from the definition). As a consequence, the question is now not whether a function needs a rule, but whether a function should have a rule. If it should, it is often (though not always) for performance reasons, which means you probably want the EnzymeRule to be fast.
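
The mutation hazard in point 1 can be reproduced without any AD package. The following is a minimal sketch in plain Julia; `matmul_with_pullback` and `matmul_with_caching_pullback` are hypothetical helpers written for illustration, not ChainRules or Enzyme API. The first mimics a ChainRules-style rule whose pullback keeps A and B by reference. The second copies its inputs, which is roughly the trade-off an importer accepts when it caches inputs: correct under mutation, but with memory cost that grows with the array size.

```julia
# Hypothetical ChainRules-style rule for C = A * B (illustration only).
# The pullback closes over A and B by reference, not by value.
function matmul_with_pullback(A, B)
    C = A * B
    pullback(dC) = (dC * B', A' * dC)   # returns (dA, dB)
    return C, pullback
end

# Same rule, but the inputs are copied so later mutation cannot affect the pullback.
function matmul_with_caching_pullback(A, B)
    A_saved, B_saved = copy(A), copy(B)
    C = A * B
    pullback(dC) = (dC * B_saved', A_saved' * dC)
    return C, pullback
end

B = [1.0 0.0; 0.0 1.0]
dC = ones(2, 2)

# By-reference pullback: overwrite A between the forward and reverse pass.
A = [1.0 2.0; 3.0 4.0]
C, pullback = matmul_with_pullback(A, B)
dB_expected = A' * dC            # gradient w.r.t. B using the A seen in the forward pass
A .= 0.0                         # in-place overwrite of A
_, dB_stale = pullback(dC)       # the pullback now reads the overwritten A
stale_is_wrong = dB_stale != dB_expected

# Caching pullback: the same overwrite no longer changes the result.
A2 = [1.0 2.0; 3.0 4.0]
C2, safe_pullback = matmul_with_caching_pullback(A2, B)
dB_expected2 = A2' * dC
A2 .= 0.0
_, dB_safe = safe_pullback(dC)
cached_is_right = dB_safe == dB_expected2
```

Note that the by-reference version raises no error: it simply returns a gradient computed from the overwritten data, which is why the failure is silent.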

I'd earlier written ChainRules to EnzymeRules importers for Forward and Reverse Mode in this PR: https://github.com/EnzymeAD/Enzyme.jl/pull/996/files (see import_frule and import_rrule). These may be helpful for a quick conversion, but they have certain limitations, both in what they support and in performance.

Specifically, from the docstrings:

    import_frule(::fn, tys...)
Automatically import a ChainRules.frule as a custom forward mode EnzymeRule. When called in batch mode, this
will end up calling the primal multiple times, which may result in incorrect behavior if the function mutates,
and slow code, always. Importing the rule from ChainRules is also likely to be slower than writing your own rule,
and may also be slower than not having a rule at all.

Use with caution.

Enzyme.@import_frule(typeof(Base.sort), Any);
x=[1.0, 2.0, 0.0]; dx=[0.1, 0.2, 0.3]; ddx = [0.01, 0.02, 0.03];
Enzyme.autodiff(Forward, sort, Duplicated, BatchDuplicated(x, (dx,ddx)))
Enzyme.autodiff(Forward, sort, DuplicatedNoNeed, BatchDuplicated(x, (dx,ddx)))
Enzyme.autodiff(Forward, sort, DuplicatedNoNeed, BatchDuplicated(x, (dx,)))
Enzyme.autodiff(Forward, sort, Duplicated, BatchDuplicated(x, (dx,)))
# output
(var"1" = [0.0, 1.0, 2.0], var"2" = (var"1" = [0.3, 0.1, 0.2], var"2" = [0.03, 0.01, 0.02]))
(var"1" = (var"1" = [0.3, 0.1, 0.2], var"2" = [0.03, 0.01, 0.02]),)
(var"1" = [0.3, 0.1, 0.2],)
(var"1" = [0.0, 1.0, 2.0], var"2" = [0.3, 0.1, 0.2])

    import_rrule(::fn, tys...)
Automatically import a ChainRules.rrule as a custom reverse mode EnzymeRule. When called in batch mode, this
will end up calling the primal multiple times which results in slower code. This macro assumes that the underlying
function to be imported is read-only, and returns a Duplicated or Const object. This macro also assumes that the
inputs permit a .+= operation and that the output has a valid Enzyme.Compiler.make_zero function defined.
Finally, this macro falls back to almost always caching all of the inputs, even if it may not be needed for the
derivative computation.
As a result, this auto importer is also likely to be slower than writing your own rule, and may also be slower
than not having a rule at all.

Use with caution.

Enzyme.@import_rrule(typeof(Base.sort), Any);

I'm potentially okay with it being made into an extension package and marked as deprecated, or with the imported rules emitting warnings when run.

cc @vchuravy for your thoughts.

In any case, the PR needs a small amount of rebase work before it can be merged, if you'd be interested in helping.

@wsmoses reopened this on Feb 13, 2024
@CarloLucibello
Collaborator

An opt-in approach like the one in #996 would already be valuable. I understand the performance limitations, but slow is better than not working at all. Related to this, it seems to me that writing an Enzyme rule is much more difficult than writing a ChainRules one, so I would really appreciate the possibility of a quick translation, and in many cases typical of DL the performance hit would probably be negligible.

@wsmoses
Member

wsmoses commented Feb 13, 2024

@CarloLucibello go for it, reviving the PR would be a welcome contribution.

I'm more skeptical of the negligible performance hit, because the cost of copying and caching a bunch of unnecessary memory scales with the size of the tensors.

@CarloLucibello
Collaborator

> @CarloLucibello go for it, reviving the PR would be a welcome contribution.

I'm not familiar with Enzyme's internals (honestly, I have little familiarity with Enzyme in general yet), so I don't think I can help much. Do you want help with testing it?

@wsmoses
Member

wsmoses commented May 13, 2024

With #996 now in, we have macros to import both ChainRules frules and rrules, so I am closing this.

@wsmoses closed this as completed on May 13, 2024
4 participants