Numerically stable parallel cumsum-based WKV + jax/tf/keras implementations #189
jackd started this conversation in Show and tell
Hi there, love the work. Having dug through the paper in more detail recently I realized the WKV implementation has some similarities with some ongoing work I'm involved in, so I hacked up a proof of concept using keras / keras-nlp here. I've included a theory page, but to summarise:
- by factoring out `w**(t-1)`, the WKV sum can be expressed as a cumulative sum
- numerical stability comes from working with pairs `(v, t)`, where the actual value `z` is represented as `z = exp(t) * v`
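A minimal NumPy sketch of the two points above, for illustration only (this is not the author's keras-nlp implementation). It assumes the RWKV-4 style recurrence `a[n] = w * a[n-1] + exp(k[n]) * v[n]` for a single channel with `0 < w < 1`, passed in as `log_w = log(w)`, and it omits the `u` bonus term and the final `a / b` division. Indices are 0-based, so the factored-out power is `w**n` rather than `w**(t-1)`. The helper names (`combine`, `inclusive_scan`, `decayed_cumsums`, `reference_cumsums`) are hypothetical, and the Hillis-Steele scan is a stand-in for something like `jax.lax.associative_scan`.

```python
import numpy as np


def combine(a, b):
    """Add two values stored as (v, t) pairs, each representing exp(t) * v.

    Shifting by m = max(ta, tb) means exp() is only applied to
    non-positive numbers, so nothing overflows.
    """
    va, ta = a
    vb, tb = b
    m = np.maximum(ta, tb)
    return va * np.exp(ta - m) + vb * np.exp(tb - m), m


def inclusive_scan(v, t):
    """Hillis-Steele inclusive scan of `combine` along axis 0 (log2(T) vectorized steps)."""
    v, t = v.copy(), t.copy()
    offset = 1
    while offset < v.shape[0]:
        # The right-hand side is fully evaluated from the old arrays before assignment.
        v[offset:], t[offset:] = combine((v[:-offset], t[:-offset]), (v[offset:], t[offset:]))
        offset *= 2
    return v, t


def decayed_cumsums(k, v, log_w):
    """Hypothetical helper: a[n] = sum_{i<=n} w**(n-i) * exp(k[i]) * v[i], and b[n] with v = 1.

    Factoring out w**n turns the decayed sum into a plain cumulative sum,
    which is carried out in the (v, t) representation.
    """
    n = np.arange(k.shape[0])
    t = k - n * log_w  # log-scale of exp(k[i]) * w**(-i); grows with i
    num_v, num_t = inclusive_scan(v, t)
    den_v, den_t = inclusive_scan(np.ones_like(v), t)
    # Multiply w**n back in (in log space) before leaving the (v, t) representation.
    a = num_v * np.exp(num_t + n * log_w)
    b = den_v * np.exp(den_t + n * log_w)
    return a, b


def reference_cumsums(k, v, log_w):
    """Sequential recurrence a[n] = w * a[n-1] + exp(k[n]) * v[n], used for comparison."""
    w = np.exp(log_w)
    a = np.zeros_like(v)
    b = np.zeros_like(v)
    acc_a = acc_b = 0.0
    for i in range(k.shape[0]):
        acc_a = w * acc_a + np.exp(k[i]) * v[i]
        acc_b = w * acc_b + np.exp(k[i])
        a[i], b[i] = acc_a, acc_b
    return a, b
```

With a strong decay (for example `log_w = -5.0` over 400 steps) the naive factored form `np.cumsum(np.exp(t) * v)` overflows, since `t` grows to roughly `5 * 400 = 2000` while float64 overflows past about `exp(709)`. In the pair representation, `exp` is only applied to the non-positive shifts inside `combine` and, at the end, to `num_t + n * log_w`, which is bounded above by `max(k)`.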
I've included a very rough performance summary, which shows promise. That said:
This is a side project of a side project for me, so while I've enjoyed doing it, I can't afford to spend much longer fine-tuning a backend I understand little about. I'm pretty confident a CUDA implementation based on Thrust's `inclusive_scan` would be straightforward and would perform considerably better than my Triton implementation, but since I've never written custom PyTorch bindings, that's a project I'm going to pass on (if anyone decides to take it up, I can offer a basic sketch).
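What a Thrust-style `inclusive_scan` needs from the `(v, t)` representation is an associative binary operator, so that a parallel scan can regroup the additions freely. The pure-Python sketch below checks that property on terms whose exponents are too large to exponentiate directly; `combine` and the example `terms` are hypothetical, and `itertools.accumulate` is only a sequential stand-in for a scan. In a CUDA version the same operator would presumably be written as a device functor over a pair type and passed as the `binary_op` argument.

```python
import itertools
import math


def combine(a, b):
    """Add exp(ta) * va and exp(tb) * vb, returning a (v, t) pair with t = max(ta, tb)."""
    (va, ta), (vb, tb) = a, b
    m = max(ta, tb)
    return (va * math.exp(ta - m) + vb * math.exp(tb - m), m)


# Hypothetical terms: math.exp(800.0) alone would raise OverflowError.
terms = [(0.5, 800.0), (-1.25, 805.0), (2.0, 790.0), (1.0, 810.0)]

# Associativity: grouping the additions differently gives the same pair (up to rounding).
a, b, c = terms[:3]
left = combine(combine(a, b), c)
right = combine(a, combine(b, c))

# Sequential inclusive scan; a parallel scan with an associative operator yields the same prefixes.
prefixes = list(itertools.accumulate(terms, combine))
```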
Hope this helps someone :)