Why is dividing by e**(1/4) for both keys and queries more memory efficient? #5

Open
esvhd opened this issue Aug 28, 2019 · 4 comments

@esvhd

esvhd commented Aug 28, 2019

Hi,

Would you mind explaining why the following code is more memory efficient than just dividing one of them by sqrt(e)?

former/former/modules.py

Lines 48 to 52 in 7b12ae6

queries = queries / (e ** (1/4))
keys = keys / (e ** (1/4))
# - Instead of dividing the dot products by sqrt(e), we scale the queries and keys.
#   This should be more memory efficient.

Thank you.

@pbloem
Owner

pbloem commented Aug 29, 2019

Sure. This is because pytorch retains the history of all computation steps: it remembers the inputs and outputs of each operation so that it can compute the gradients.

For long sequences, the memory use of a transformer is dominated by the matrix dot (the dot products of all queries with all keys). Every time you operate on this matrix, pytorch needs to remember another t * t floats. If we move the scaling to the keys and queries instead, it only needs to remember 2 * t * e floats. When t is (much) bigger than e, this is more efficient.
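
For concreteness, a small sketch of the two variants (the shapes and names below are illustrative, not taken from the repo); both compute the same scores, since the two factors of e ** (1/4) multiply to the usual 1 / sqrt(e):

import torch

# Toy shapes: batch b, sequence length t, embedding dimension e, with t >> e.
b, t, e = 4, 1024, 64
queries = torch.randn(b, t, e, requires_grad=True)
keys = torch.randn(b, t, e, requires_grad=True)

# Variant A: compute the raw dot products, then scale the (b, t, t) matrix.
dot_a = torch.bmm(queries, keys.transpose(1, 2))
dot_a = dot_a / (e ** (1 / 2))

# Variant B: scale queries and keys first; these are only (b, t, e) tensors,
# which is much smaller than (b, t, t) when t is large.
q = queries / (e ** (1 / 4))
k = keys / (e ** (1 / 4))
dot_b = torch.bmm(q, k.transpose(1, 2))

assert torch.allclose(dot_a, dot_b, atol=1e-5)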

I didn't test this very thoroughly, so there may be some optimizations that I don't know about. But some quick tests with the debugger seem to bear this out.

@pbloem pbloem closed this as completed Aug 29, 2019
@esvhd
Author

esvhd commented Aug 29, 2019 via email

@TheGrayFrost

TheGrayFrost commented Jun 24, 2020

Doesn't this make the square-root scaling an artefact of the pytorch implementation rather than a general thing?
In tensorflow, when you build the computation graph, any memory that is not re-used (like the old value of dot) can be modified in place. I thought pytorch did tricks like this too: by writing dot = dot / (e ** (1/2)), you don't really need to keep the old value of dot - nobody can ever ask for it back. And pytorch code is generally as performant as tensorflow, so I assumed its dynamic graph would do this as well...
But you saw in the debugger that scaling dot used more memory...
Could you tell me which debuggers or profilers you use for pytorch? I'd like to investigate this further...

I think we also gain a computational advantage by not having to divide the bigger dot matrix. If that's the case, we'd do even less computation by dividing only the keys or the queries by e ** (1/2), rather than both by e ** (1/4) (at the cost of looking asymmetrical).
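
For reference, a quick numerical check (made-up shapes, not code from the thread) that dividing only one side by e ** (1/2) gives the same dot products as dividing both sides by e ** (1/4):

import torch

b, t, e = 4, 256, 64
queries = torch.randn(b, t, e)
keys = torch.randn(b, t, e)

# Symmetric scaling, as in the snippet under discussion.
dot_sym = torch.bmm(queries / (e ** (1 / 4)), (keys / (e ** (1 / 4))).transpose(1, 2))

# Asymmetric scaling: divide only the queries by sqrt(e).
dot_asym = torch.bmm(queries / (e ** (1 / 2)), keys.transpose(1, 2))

assert torch.allclose(dot_sym, dot_asym, atol=1e-5)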

Could you please share your thoughts on this?

@pbloem
Owner

pbloem commented Jun 25, 2020

I think this is easier to do with a static computation graph (like in tf 1.0). If you build the graph dynamically, and do something like

dot1 = f(x)
dot2 = g(dot1)

you don't know whether dot1 might be re-used in some other branch of the computation. With a static computation graph you can do more optimization during compilation when you know that dot1 is never used again.

With

dot = f(x)
dot = g(dot)

you can tell by inspecting the code, of course, that the first dot object will never be used again, but I don't think pytorch has that kind of runtime access to the structure of the code. Also, I don't know if you could easily fuse the two operations to save memory for the backward pass in a dynamic computation graph.

However, thinking about this again, for multiplying by a constant you can work out the gradients without storing the input values, so it shouldn't matter at any rate. I don't remember how I tested the memory use, but I did notice a clear jump in memory. I'll reopen this ticket to try again.
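
A tiny illustration of that point, as a sketch with a custom autograd function (this is not how pytorch implements scalar division internally):

import torch

class ScaleByConstant(torch.autograd.Function):
    # The backward of x -> x * c only needs the constant c, not the input x,
    # so no tensor-sized activation has to be saved for the backward pass.
    @staticmethod
    def forward(ctx, x, c):
        ctx.c = c  # a plain Python float
        return x * c

    @staticmethod
    def backward(ctx, grad_out):
        return grad_out * ctx.c, None  # no gradient for the constant

x = torch.randn(3, 3, requires_grad=True)
y = ScaleByConstant.apply(x, 1 / (64 ** 0.5))
y.sum().backward()
assert torch.allclose(x.grad, torch.full_like(x, 1 / (64 ** 0.5)))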

As for moving the multiplication even further back, I don't expect it would make a big difference. Multiplying by a constant will be at most linear in the size of the dot matrix, so it will vanish compared to multiplying the dot matrix by the values. Still, it might be worth testing to see what the impact is.
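
In case it helps with re-testing, one way to compare peak memory for the two variants (a rough sketch assuming a CUDA device; the function name and shapes are made up):

import torch

def peak_mem(scale_dot, b=8, t=2048, e=64, device="cuda"):
    # One forward + backward through the attention-score computation;
    # returns the peak CUDA memory allocated, in bytes.
    torch.cuda.reset_peak_memory_stats(device)
    q = torch.randn(b, t, e, device=device, requires_grad=True)
    k = torch.randn(b, t, e, device=device, requires_grad=True)
    if scale_dot:
        dot = torch.bmm(q, k.transpose(1, 2)) / (e ** (1 / 2))
    else:
        dot = torch.bmm(q / (e ** (1 / 4)), (k / (e ** (1 / 4))).transpose(1, 2))
    dot.sum().backward()
    return torch.cuda.max_memory_allocated(device)

if torch.cuda.is_available():
    print("scale dot matrix:   ", peak_mem(True))
    print("scale queries/keys: ", peak_mem(False))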

@pbloem pbloem reopened this Jun 25, 2020