Add option for recomputing the casted weight during backwards #186
base: main
Conversation
ed(f"call_method {self} {name} {args} {kwargs}")
[2024-01-12 18:38:24,732] [8/0] torch._dynamo.variables.higher_order_ops: [ERROR] File "/home/drisspg/miniconda3/envs/nightly/lib/python3.10/site-packages/torch/_dynamo/exc.py", line 193, in unimplemented
[2024-01-12 18:38:24,732] [8/0] torch._dynamo.variables.higher_order_ops: [ERROR] raise Unsupported(msg)
[2024-01-12 18:38:24,732] [8/0] torch._dynamo.variables.higher_order_ops: [ERROR] torch._dynamo.exc.Unsupported: call_method GetAttrVariable(TensorVariable(), _data) stride [] {}
You love to see it! Why can't I call shape? Avoiding calling _data is toughhhhh
cc @bdhirsh As far as I can tell, this is erroring because of these calls to the tensor attributes: https://github.com/pytorch-labs/float8_experimental/pull/186/files#diff-00f68398c8aad5a3e946cccd7211a80841da9403d6c664452a45e04101bea6d6R84-R93 I know that in the past, any time we try to access the subclass's attributes outside of the __torch_dispatch__ code, this errors. I don't have any idea how to work around this, since I think we need this autograd function and hence can't use __torch_dispatch__.
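For reference, a minimal sketch (not this PR's actual code) of the pattern being described: a custom autograd.Function whose forward reads subclass attributes like `_data` outside of `__torch_dispatch__`. The toy wrapper subclass and all names below are illustrative only.

```python
import torch

class ToyFloat8Tensor(torch.Tensor):
    """Toy stand-in for a float8 tensor subclass that stores its payload in `_data`."""

    @staticmethod
    def __new__(cls, data: torch.Tensor):
        # Wrapper subclass: the outer tensor advertises float32, the raw payload lives in _data.
        self = torch.Tensor._make_wrapper_subclass(cls, data.shape, dtype=torch.float32)
        self._data = data
        return self

class CastWeight(torch.autograd.Function):
    @staticmethod
    def forward(ctx, weight):
        # These attribute accesses happen in regular autograd.Function code, not inside
        # __torch_dispatch__, which is what torch.compile flags as
        # "call_method GetAttrVariable(TensorVariable(), _data) ...".
        ctx.save_for_backward(weight._data)
        return weight._data.to(torch.bfloat16)

    @staticmethod
    def backward(ctx, grad_output):
        # Placeholder backward; a real implementation would re-cast from the saved data.
        return grad_output
```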
The above makes sense to me for this particular setting, if we choose to have a setting. It would be nice to not have a setting at all unless we need it. I feel like FSDP is unusable for real workloads without this, so if the recomputation is fast enough, why not just have it as the only path?
Great! Can we also post throughput metrics on 8-GPU FSDP? If there is a slowdown, having a smaller benchmark to capture + debug it would be useful.
Summary
See #185 for more detail.
Disclaimer
Ughh, idk; PT2 doesn't let me control what gets recomputed, and I am having trouble interpreting the tea leaves.
For now, ignore all the performance numbers below except for max memory usage. The min-cut partitioner is actually undoing the recompute for backwards and saving the casted weight tensor. cc @Chillee
See: pytorch/pytorch#117901
Single GPU Linear numbers:
FSDP Memory Usage
Verified on a single-node 8-GPU FSDP run that the memory usage is no longer scaling:
FSDP Performance
Using a single-node 8-GPU FSDP setup with compile:
Single GPU Memory usage
In eager mode, using this test script: https://gist.github.com/drisspg/75a792f97f5b8fa77f32af7f5280bae5, I am seeing the following max memory used:
Recompute = False:
Max Cuda Memory Used: 1.8438 GiB
Recompute = True:
Max Cuda Memory Used: 1.7032 GiB
A difference of ~0.14 GiB. We should expect a memory saving of (4096**2) * (1 byte) * 10 (layers) / 1024**3 (bytes per GiB) = 0.15625 GiB.
Also verified via the memory traces in the gist.
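For context, here is a rough sketch of how a peak-memory comparison like the one above can be collected in eager mode (the linked gist may do this differently); the layer count and dimensions simply mirror the arithmetic above.

```python
import torch
import torch.nn as nn

layers, dim = 10, 4096
# Expected saving from not storing a casted fp8 copy of each 4096x4096 weight:
# (4096**2 elements) * (1 byte) * (10 layers) / 1024**3 = 0.15625 GiB
expected_saving_gib = dim * dim * 1 * layers / 1024**3

model = nn.Sequential(*[nn.Linear(dim, dim, bias=False) for _ in range(layers)]).cuda()
x = torch.randn(16, dim, device="cuda", requires_grad=True)

torch.cuda.reset_peak_memory_stats()
model(x).sum().backward()
print(f"Max Cuda Memory Used: {torch.cuda.max_memory_allocated() / 1024**3:.4f} GiB")
print(f"Expected saving from recomputing the weight cast: {expected_saving_gib} GiB")
```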
Questions
This is kind of a meaty PR that depends on a PyTorch PR (pytorch/pytorch#117667), but I am curious if people have strong feelings on the UX.
I chose not to make the "recompute weight cast" a config setting, instead having it as a module attribute. The swap_linear function will set this for every linear it swaps; in theory, from_float is granular enough to do this on a per-linear basis. Is there any reason why having it as a global config would be better (even a global config setting that alters the swap function's behavior)?
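To make the question concrete, here is a hypothetical sketch (not this PR's exact API) contrasting the two options. The `Float8Linear` stub, the `recompute_weight_cast` attribute name, and the `swap_linears` helper are all assumptions for illustration.

```python
import torch.nn as nn

class Float8Linear(nn.Linear):
    # Stub standing in for the real Float8Linear in float8_experimental.
    @classmethod
    def from_float(cls, mod: nn.Linear, recompute_weight_cast: bool = False) -> "Float8Linear":
        new_mod = cls(mod.in_features, mod.out_features, bias=mod.bias is not None)
        new_mod.weight = mod.weight
        new_mod.bias = mod.bias
        # Option A (this PR's direction): carry the choice as a per-module attribute.
        new_mod.recompute_weight_cast = recompute_weight_cast
        return new_mod

def swap_linears(module: nn.Module, recompute_weight_cast: bool = False) -> nn.Module:
    # The swap sets the attribute on every Linear it converts, so the choice can,
    # in principle, also be made per-linear by calling from_float directly.
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            setattr(module, name, Float8Linear.from_float(child, recompute_weight_cast))
        else:
            swap_linears(child, recompute_weight_cast)
    return module

# Option B (the alternative raised above): a single global config flag,
# e.g. a module-level config.recompute_weight_cast, read by the swap function instead.
```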