Further ChainRulesCore.rrule Integration #254

Merged
willtebbutt merged 68 commits into main from wct/actually-improve-rrule-integration
Oct 8, 2024

Conversation

willtebbutt
Member

@willtebbutt willtebbutt commented Sep 13, 2024

At present the extent to which ChainRulesCore.rrules can be (straightforwardly) used in this package is quite limited. The purpose of this PR is to:

  • make it easier to incorporate more interesting rules, thus closing "Improving ChainRules rule usage" (#191),
  • include some failing tests to ensure that we get clear errors when conversions between Tapir and ChainRules tangent types go wrong,
  • explain the limitations of what can be done,
  • provide guidance on the situations in which you should consider importing an rrule from ChainRules,
  • import rules from NNlib.jl and test that everything works (see "NNlib Support", #171),
  • reduce the quantity of code duplication between the rrule wrappers with and without kwargs,
  • improve the macro for use with kwargs rrules,
  • improve the macro's handling of where terms so that diagonal dispatch can be done,
  • improve the ChainRules interop documentation,
  • add and test rules associated with LuxLib.jl,
  • integration-test Lux.jl,
  • add a section to the docs on quick ways to write rules (method overlays, getting rules from rrules, simple_zero_adjoint), and
  • turn docstrings into doctests in the @from_rrule macro.
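
To illustrate the "quick ways to write rules" item, here is a minimal sketch of declaring that a function has zero derivative, using the @zero_adjoint macro. The function `always_five` is hypothetical and exists only for this example. The macro form `@zero_adjoint ctx sig` and the helper names (`DefaultCtx`, `zero_fcodual`, `rrule!!`, `NoRData`) are as I understand the Mooncake documentation, and may differ between versions.

```julia
using Mooncake: @zero_adjoint, DefaultCtx, rrule!!, zero_fcodual, NoRData

# Hypothetical function: its output does not depend on its input,
# so the derivative with respect to x is zero.
always_five(x) = 5

# Declare a zero-derivative rule for this signature in the default context.
@zero_adjoint DefaultCtx Tuple{typeof(always_five), Any}

# Forwards pass: returns the output (paired with its tangent data) and a pullback.
out, pb!! = rrule!!(zero_fcodual(always_five), zero_fcodual(3.0))

# Reverse pass: the output is an Int, so its reverse-data is NoRData().
# The second element of the result is the gradient with respect to x.
grads = pb!!(NoRData())
```

This avoids writing the forwards and reverse passes by hand when all that needs to be said is that nothing is differentiable.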

The reason for incorporating the rules from NNlib.jl in this PR, rather than a subsequent one, is that the main place we wish to import rules from in the near term is NNlib.jl. Consequently, doing that in this PR will give us a good sense of whether the new functionality we introduce here is sufficient.

edit: as a happy consequence of thinking more carefully about how we incorporate functionality from ChainRules, I've removed a lot of code from the @from_rrule macro and put it in a regular Julia function. This makes this part of the codebase much more straightforward to understand.
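
For readers unfamiliar with the macro, the following is a minimal sketch of importing a ChainRulesCore.rrule via @from_rrule. The function `foo` and its rule are hypothetical. The macro form `@from_rrule ctx sig` and the helpers `DefaultCtx`, `zero_fcodual`, and `rrule!!` follow the Mooncake documentation as I understand it, so treat the exact names as assumptions that may change between versions.

```julia
using Mooncake: @from_rrule, DefaultCtx, rrule!!, zero_fcodual
import ChainRulesCore

# Hypothetical function used only for illustration.
foo(x::Real) = 5x

# A standard ChainRulesCore rule: returns the primal value and a pullback.
# The pullback returns one cotangent per argument, including the function itself.
function ChainRulesCore.rrule(::typeof(foo), x::Real)
    foo_pullback(dy::Real) = (ChainRulesCore.NoTangent(), 5dy)
    return foo(x), foo_pullback
end

# Ask Mooncake to build its own rule from the ChainRulesCore one.
# The signature is restricted to IEEE floats, a type for which converting
# between ChainRules and Mooncake tangents is straightforward.
@from_rrule DefaultCtx Tuple{typeof(foo), Base.IEEEFloat}

# Forwards pass, then the pullback with an output cotangent of 1.0.
out, pb!! = rrule!!(zero_fcodual(foo), zero_fcodual(5.0))
grads = pb!!(1.0)
```

Restricting the signature to concrete, simple argument types is deliberate in this sketch: the more exotic the types, the more likely it is that the tangent-type conversion mentioned in the list above fails.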


codecov bot commented Sep 13, 2024

Codecov Report

Attention: Patch coverage is 89.26174% with 16 lines in your changes missing coverage. Please review.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| ext/MooncakeLuxLibExt.jl | 57.14% | 12 Missing ⚠️ |
| src/interpreter/ir_utils.jl | 85.00% | 3 Missing ⚠️ |
| src/tools_for_rules.jl | 98.90% | 1 Missing ⚠️ |

| Files with missing lines | Coverage Δ |
|---|---|
| src/Mooncake.jl | 100.00% <ø> (ø) |
| src/codual.jl | 91.17% <ø> (-0.50%) ⬇️ |
| src/interpreter/abstract_interpretation.jl | 80.00% <100.00%> (+1.05%) ⬆️ |
| src/interpreter/s2s_reverse_mode_ad.jl | 92.97% <100.00%> (+0.01%) ⬆️ |
| src/rrules/avoiding_non_differentiable_code.jl | 100.00% <ø> (ø) |
| src/rrules/blas.jl | 98.26% <ø> (ø) |
| src/rrules/builtins.jl | 99.09% <100.00%> (+0.84%) ⬆️ |
| src/rrules/fastmath.jl | 100.00% <100.00%> (ø) |
| src/rrules/foreigncall.jl | 94.52% <ø> (+0.51%) ⬆️ |
| src/rrules/misc.jl | 100.00% <ø> (+2.53%) ⬆️ |

... and 4 more

Contributor

github-actions bot commented Sep 13, 2024

Performance Ratio:
Ratio of time to compute gradient and time to compute function.
Warning: results are very approximate! See here for more context.

┌────────────────────────────┬──────────┬─────────┬─────────────┬─────────┐
│                      Label │ Mooncake │  Zygote │ ReverseDiff │  Enzyme │
│                     String │   String │  String │      String │  String │
├────────────────────────────┼──────────┼─────────┼─────────────┼─────────┤
│                   sum_1000 │    112.0 │     1.0 │         5.5 │     1.9 │
│                  _sum_1000 │     7.54 │  1360.0 │        32.9 │   0.106 │
│               sum_sin_1000 │     2.46 │    1.62 │        10.9 │    1.04 │
│              _sum_sin_1000 │     2.93 │   292.0 │        16.0 │    1.44 │
│                   kron_sum │     59.4 │    9.91 │       188.0 │    11.3 │
│              kron_view_sum │     77.1 │    10.5 │       204.0 │    12.4 │
│      naive_map_sin_cos_exp │      3.2 │ missing │        8.97 │     2.8 │
│            map_sin_cos_exp │     4.78 │    1.76 │        7.61 │    3.44 │
│      broadcast_sin_cos_exp │     4.32 │    2.57 │         1.7 │    2.88 │
│                 simple_mlp │     8.41 │    3.19 │        15.3 │    3.27 │
│                     gp_lml │     14.9 │    4.35 │     missing │ missing │
│ turing_broadcast_benchmark │     7.72 │ missing │        37.9 │ missing │
│         large_single_block │      3.9 │  4110.0 │        31.1 │    2.24 │
└────────────────────────────┴──────────┴─────────┴─────────────┴─────────┘

@willtebbutt willtebbutt self-assigned this Oct 3, 2024
@willtebbutt willtebbutt requested a review from yebai October 3, 2024 09:20
@willtebbutt
Member Author

@yebai could you let me know what you think about the docs?

Contributor

@yebai yebai left a comment

It looks good overall. Below are a few minor improvement suggestions for overlay and a variant of @zero_adjoint with default context.

I didn't spend much time checking the code, though. Perhaps @sunxd3 and @mhauru can do a pass over the code.

ext/MooncakeDynamicPPLExt.jl
ext/MooncakeNNlibExt.jl (outdated)
src/tools_for_rules.jl
src/tools_for_rules.jl
src/tools_for_rules.jl
Contributor

@yebai yebai left a comment

It looks good to me!

Contributor

@mhauru mhauru left a comment

I don't understand most of the code, but I read the docstrings and skimmed the implementations and tests. Spotted a few typos and had a few questions, but no significant concerns to raise.

docs/src/tools_for_rules.md (outdated)
src/interpreter/ir_utils.jl
src/tools_for_rules.jl
src/tools_for_rules.jl
src/tools_for_rules.jl (outdated)
docs/src/tools_for_rules.md (outdated)
test/tools_for_rules.jl (outdated)
@willtebbutt
Member Author

I've downgraded CI to run on 1.10 for the time being so we can merge this while I figure out the upgrades for 1.11.

@willtebbutt willtebbutt merged commit d6110f0 into main Oct 8, 2024
16 of 17 checks passed
@willtebbutt willtebbutt deleted the wct/actually-improve-rrule-integration branch October 8, 2024 13:31
This was referenced Oct 8, 2024
3 participants