[Experimental] expose dynamic upcasting of layers as experimental APIs #9949
Thanks @sayakpaul, yes, I'm very interested!

Yes, that is why we think exposing this API, even in an experimental capacity, makes a lot of sense!
Just to add my two cents and to make others aware of our in-person discussion, I don't think we could ever make it fully compatible with

```python
from typing import Any

import torch

# ModelHook, set_args_dtype, set_kwargs_dtype and apply_hook_to_module are
# assumed to be defined elsewhere (not shown in this comment).
# `pipe` is an already-loaded diffusers pipeline (not shown).


class LayerwiseUpcastingHook(ModelHook):
    def __init__(self, compute_dtype: torch.dtype = torch.bfloat16, storage_dtype: torch.dtype = torch.float8_e5m2) -> None:
        super().__init__()
        self.compute_dtype = compute_dtype
        self.storage_dtype = storage_dtype

    def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any:
        # Cast inputs and weights to the compute dtype before the forward pass.
        set_args_dtype(args, dtype=self.compute_dtype)
        set_kwargs_dtype(kwargs, dtype=self.compute_dtype)
        module.to(dtype=self.compute_dtype)
        return args, kwargs

    def post_forward(self, module: torch.nn.Module, output: Any) -> Any:
        # Cast weights back to the storage dtype once the forward pass is done.
        module.to(dtype=self.storage_dtype)
        return output


# The below can be abstracted away to whatever makes sense
for block in pipe.transformer.transformer_blocks:
    hook = LayerwiseUpcastingHook(compute_dtype=torch.bfloat16, storage_dtype=torch.float8_e5m2)
    apply_hook_to_module(block, hook, append=True)
```

This is similar to how

I would also like to point out, based on our discussion, that this approach gives power users more granular control and can yield more memory savings than naively applying upcasting at the transformer-block level. Here is an example.

Our maximum memory requirement is bounded by the memory for weights plus intermediate activations. Quantizing the weights, or storing them in a lower-precision dtype, typically reduces only the weight memory; the activation peaks remain the same. Assume a very simple model consisting of transformer blocks (attention + a large feed-forward). Applying layerwise upcasting from fp8 to fp16 at the transformer-block level requires:

    fp8 model memory + (fp16 single-block memory - fp8 single-block memory) + fp16 single-block activation peak

If we instead upcast more granularly, at each attention and each feed-forward layer separately, the maximum memory footprint is lower: only the weights of the attention OR the feed-forward MLP are upcast at a time, instead of both at once, so the middle term shrinks to the fp16 - fp8 difference of the larger of the two sub-layers. The tradeoff is slightly slower inference, because there are more upcasts and downcasts. The most granular case would resemble what we do in sequential CPU offload, except that here we would upcast and downcast at the nn.Module level instead of the ModelMixin level. I think these tradeoffs are still worth it for making ever larger models accessible for inference, so limiting layerwise upcasting to the transformer-block level is a bit restrictive (as we discussed).
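To make the estimate above concrete, here is a back-of-the-envelope sketch. The model sizes and the activation peak are made-up numbers for illustration only, and the helper name peak_bytes is hypothetical. It assumes fp8 storage (1 byte per parameter), fp16 compute (2 bytes per parameter), and the same activation peak in both cases.

```python
GiB = 1024 ** 3


def peak_bytes(total_params: int, upcast_params: int, activation_peak: int,
               storage_bytes: int = 1, compute_bytes: int = 2) -> int:
    """Rough peak memory for layerwise upcasting (hypothetical helper).

    Mirrors the estimate in the comment:
      all weights in the storage dtype
      + (compute-dtype copy - storage-dtype copy) of the unit being upcast
      + activation peak while that unit runs.
    """
    weights = total_params * storage_bytes
    upcast_overhead = upcast_params * (compute_bytes - storage_bytes)
    return weights + upcast_overhead + activation_peak


# Made-up toy model: 24 blocks, each = attention (50M params) + feed-forward (100M params).
n_blocks, attn, ff = 24, 50_000_000, 100_000_000
total = n_blocks * (attn + ff)
act = 1 * GiB  # assumed activation peak, identical in both cases

block_level = peak_bytes(total, attn + ff, act)   # upcast a whole block at once
granular = peak_bytes(total, max(attn, ff), act)  # upcast attention OR feed-forward at once

print(f"block-level: {block_level / GiB:.3f} GiB")
print(f"granular:    {granular / GiB:.3f} GiB")
print(f"saved:       {(block_level - granular) / 1e6:.0f} MB")
```

The saving equals the fp16 - fp8 difference of the smaller sub-layer (here 50M parameters × 1 byte), so it grows as the sub-layers become more evenly sized relative to the block.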
I don't think we've settled on the hooks-based approach yet, though, have we? For caching, for example, we're still debating whether a mixin class would make more sense. OTOH, this experimental API is simple, easy to use, and covers most of the use cases I can imagine.
As mentioned in the description, the experimental API is still meaningful when the weights are stored in a lower precision without any quantization statistics: unlike quantization, no bitwise packing and unpacking is needed.
Well,

I think this approach would be consistent with what accelerate does for CPU offloading, and it is very clean, since both device and dtype can be handled the same way. It is also easier to use with any kind of model; otherwise (as in @DN6's PR), it involves making many changes to each model by adding

If I understand the PR correctly, the entry point for enabling layerwise upcasting is at the ModelMixin level, yes? You can't easily apply it to an arbitrary module unless that module derives from ModelMixin. What would be really nice is keeping the API similar-ish to
Okay, I see merit in the hook-based approach. That said, we probably shouldn't delay shipping this either: it is a simple EXPERIMENTAL feature, and it clearly enables significant memory savings with very little effort.
Doesn't need to be.
Good point! From what I understand, those explicit

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.

ping?
Sorry about the delay, @vladmandic! It wasn't planned for this release, so I was going to pick it up later. I'll open a quick prototype with some benchmarks soon.

Thanks, it's not urgent, but it's a very interesting one. I just wanted to make sure it doesn't drop off the radar, since there hadn't been an update in a long time.
Functionalities like #9177 are immensely helpful for loading a checkpoint in, say, torch.float8_e5m2, performing computation in, say, torch.float16, and then keeping the weights in torch.float8_e5m2 again.

Even though this feature isn't immediately compatible with torch.compile() and we're unsure of its repercussions, we think it's still better to expose it as an experimental API, because the memory benefits are significant.

Cc: @vladmandic as you expressed interest in this.
Cc: @a-r-r-o-w @SunMarc as we discussed this in person.
Cc: @DN6 because #9177 is his brainchild.