-
Notifications
You must be signed in to change notification settings - Fork 56
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Gradients for Flux etc.-- WIP #59
Conversation
Codecov Report
@@ Coverage Diff @@
## master #59 +/- ##
==========================================
- Coverage 77.15% 75.41% -1.74%
==========================================
Files 13 16 +3
Lines 1383 1509 +126
==========================================
+ Hits 1067 1138 +71
- Misses 316 371 +55
Continue to review full report at Codecov.
|
Codecov Report
@@ Coverage Diff @@
## master #59 +/- ##
=========================================
- Coverage 75.52% 74.02% -1.5%
=========================================
Files 13 16 +3
Lines 1344 1467 +123
=========================================
+ Hits 1015 1086 +71
- Misses 329 381 +52
Continue to review full report at Codecov.
|
While this looks like great work you are doing, I would like to inform/warn you that I don't have a lot of time to review this in the next few weeks, especially since I am also pretty new to automatic differentation. I have a keen interest in this topic, and I am open to supporting it with TensorOperations.jl, but I would actually have hoped that some of the more specific compatibility code could actually live in the corresponding packages, not here. Having it here means also a much higher responsibility from my part to keep it up and running, and might make future innovations that I would like to implement a lot harder. For example, I am not convinced that the way I currently lower the |
I dropped the ball here, sorry... but at least you know there is no rush! Maybe RFC would have been a better title for this. It was basically a challenge to see if I could put this together, and now it works, more or less. I learned quite a bit, including from your code. My PR suggestion that it live here was I guess because it has to depend on lots of internal details of your functions, not just what’s “public”. But I hear you about maintenance load, and don’t mean to impose on you. It also necessarily depends on many things in Flux etc. which look pretty different from how they did a year ago… it’s a bit awkward really that it must do both, maybe there will eventually be a nicer way. Indeed there may be smarter ways of going about this whole thing. I haven't thought very hard about this, it could be that gradients should be defined either on a higher or a lower level than these functions. That was just what seemed within reach -- 5 minutes on a piece of paper for these 3 basic operations. For now perhaps I leave it here & we can see? If it's useful to anyone perhaps they could chime in too. |
@mcabbott Thanks for you PR, it solved my problem, so far every thing works fine! |
This code is now at https://github.com/mcabbott/TensorTrack.jl instead, and a higher-level approach to the same problem is at https://github.com/mcabbott/TensorGrad.jl . |
I know that reference and have wanted to implement this in TensorOperations.jl for a long time already. In a sense, I consider that reference as a first step towards AD for tensor networks 'avant la lettre'. In fact, I started working on an implementation last week and had thought a bit about the interface. I was thinking of something like
which would create a callable object This would work perfectly with the gradients from your |
What's the status on this? This would save me a lot of hell if it worked. I'm happy to devote any amount of time to getting this working |
The gradient definitions from TensorGrad.jl (linked above) are presently bolted onto Tullio.jl. It's perhaps a slightly awkward combination but useful for now. BTW, the design there is that there's a callable struct |
UPDATES: @Jutho , It is probably much easier to define the backward rules on binary tensor contraction operations. Because the nary tensor operations calls the binary contraction function, the contraction order in the backward pass can be determined automatically by the AD engine. It always has constant overheads w.r.t. the forward pass. FYI: It is much easier to support AD now, because the AD community has switched to ChainRules, you just add a very small dependency: ChainRulesCore. (however, we still do not have a reliable AD engine in Julia 😞 ) |
I believe I found the bug in my gradient calculation today, and so I tidied it up a little for this PR. (See also #47.) The code specific to Flux is loaded via Requires, while the gradients themselves are in another file.
α
,β
, only the arrays. I would likeα::TrackedReal
to be an error, which I think will require doing slightly less promotion, but I don't precisely see where yet.trace∇A
callscontract!
, it might be helpful iftrace!
was handed somesyms
by the macro, ascontract!
is; they could be ignored elsewhere.