FSDP+PP bug where reshard_after_forward must be true #1105

Open
wconstab opened this issue May 2, 2024 · 6 comments

Comments

@wconstab
Contributor

wconstab commented May 2, 2024

https://github.com/pytorch/torchtitan/pull/161/files#diff-80b04fce2b861d9470c6160853441793678ca13904dae2a9b8b7145f29cd017aR269


IIRC @awgu mentioned there was an issue requiring this setting for the time being. I'm not sure why, or whether it has been fixed yet.

@awgu
Contributor

awgu commented May 2, 2024

This seems like an important / high(er) priority issue since FSDP + PP generally wants reshard_after_forward=False.

@kwen2501
Contributor

kwen2501 commented May 3, 2024

I believe that in the old FSDP, where the FSDP API is called on the whole model, reshard_after_forward can be figured out automatically (or at least there is a way to do so).

I don't know whether the new FSDP still allows the API to be called on the whole model. If it does, can we investigate doing this automatically so that the burden is not on the user? After all, reshard_after_forward is something of an internal detail that requires the user to understand a "corner" procedure of FSDP.

That said, following @awgu's comment, should we just do:

if pp:
    reshard_after_forward = False
else:
    reshard_after_forward = <a condition>

@awgu
Contributor

awgu commented May 3, 2024

reshard_after_forward=True == ShardingStrategy.FULL_SHARD == ZeRO-3
reshard_after_forward=False == ShardingStrategy.SHARD_GRAD_OP == ZeRO-2

It is still the same: it cannot be figured out automatically. Only the root module automatically changes to reshard_after_forward=False, since it will all-gather immediately to begin backward anyway. I would not consider this to be a "corner" procedure of FSDP, though. It is an important choice that affects the algorithm used, so users are generally aware of it.
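
To make the mapping above concrete, here is a minimal sketch in plain Python with no PyTorch dependency. The names `ZeroStage`, `zero_stage`, and `params_held_after_forward` are hypothetical and exist only for illustration. The element count assumes even sharding with no padding and ignores every other per-rank buffer.

```python
from enum import Enum


class ZeroStage(Enum):
    """Hypothetical labels for the two algorithms named in the mapping."""
    ZERO_2 = 2  # reshard_after_forward=False (SHARD_GRAD_OP)
    ZERO_3 = 3  # reshard_after_forward=True (FULL_SHARD)


def zero_stage(reshard_after_forward: bool) -> ZeroStage:
    """Translate the boolean flag into the equivalent ZeRO stage."""
    return ZeroStage.ZERO_3 if reshard_after_forward else ZeroStage.ZERO_2


def params_held_after_forward(
    numel: int, world_size: int, reshard_after_forward: bool
) -> int:
    """Parameter elements one rank keeps between forward and backward.

    Assumption: parameters shard evenly across ranks (ceil division stands
    in for padding); nothing else is counted.
    """
    if reshard_after_forward:
        # Parameters are freed back to their local shard after forward.
        return -(-numel // world_size)
    # Parameters stay unsharded (all-gathered) until backward has used them.
    return numel


if __name__ == "__main__":
    for flag in (True, False):
        held = params_held_after_forward(1000, 8, flag)
        print(flag, zero_stage(flag).name, held)
```

The trade-off is memory against communication: keeping parameters unsharded skips the all-gather in backward at the cost of holding the full parameters on every rank.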

@kwen2501
Contributor

kwen2501 commented May 3, 2024

By "corner" case, I refer to this line:

reshard_after_forward = layer_id < len(model.layers) - 1

Compared with actively choosing ZeRO-2 or ZeRO-3, I think the user is really saying: I want to use FSDP, but I also want slightly more perf. The last layer's backward starts immediately after its forward, so please don't reshard it.

@kwen2501
Contributor

kwen2501 commented May 3, 2024

Put differently, if we already know that:
zero-3 + zero-3 + ... + zero-2
is going to be a common pattern, can we package that as an offering to our users?
Should that be considered the preferred implementation of zero-3 (whole model)?
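
As a rough illustration of the pattern above, the per-layer flags could be computed up front. This is only a sketch: `plan_reshard_flags` and `pp_enabled` are hypothetical names, and it assumes the layers execute in list order. The pipeline-parallel branch follows the suggestion earlier in this thread (use False everywhere), which is the desired behavior rather than something confirmed to work at the time of this issue.

```python
def plan_reshard_flags(num_layers: int, pp_enabled: bool = False) -> list[bool]:
    """Return one reshard_after_forward flag per layer.

    Assumption: layers run forward in index order, so the last layer's
    backward starts right after its forward.
    """
    if num_layers < 0:
        raise ValueError("num_layers must be non-negative")
    if pp_enabled:
        # Suggested in this thread: with PP, never reshard after forward.
        return [False] * num_layers
    # zero-3 + zero-3 + ... + zero-2: only the last layer skips resharding.
    return [layer_id < num_layers - 1 for layer_id in range(num_layers)]


if __name__ == "__main__":
    print(plan_reshard_flags(4))                    # [True, True, True, False]
    print(plan_reshard_flags(4, pp_enabled=True))   # [False, False, False, False]
```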

@awgu
Contributor

awgu commented May 4, 2024

I see. I think since we do not know the execution order in general, we cannot do it easily in the FSDP API itself, which is a building block. Maybe a higher level API that knows how to call FSDP for some class of models could do it.
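
One possible shape for such a higher-level API is sketched below. It is not an existing FSDP or torchtitan function: `shard_layers_in_order` and `shard_fn` are hypothetical names, and the helper assumes the caller passes layers in their actual forward execution order, which is exactly the knowledge FSDP itself lacks. The sharding call is injected so the sketch runs without PyTorch.

```python
from typing import Callable, Sequence, TypeVar

M = TypeVar("M")


def shard_layers_in_order(
    layers: Sequence[M],
    shard_fn: Callable[[M, bool], None],
    pp_enabled: bool = False,
) -> None:
    """Call shard_fn(layer, reshard_after_forward) for each layer.

    Assumption: `layers` is given in forward execution order. In real use,
    shard_fn might wrap something like
    fully_shard(layer, reshard_after_forward=flag).
    """
    last = len(layers) - 1
    for layer_id, layer in enumerate(layers):
        # Skip resharding for the last layer, or for every layer under PP.
        reshard = (not pp_enabled) and layer_id < last
        shard_fn(layer, reshard)


if __name__ == "__main__":
    calls = []
    # Stand-in shard_fn that only records what it was asked to do.
    shard_layers_in_order(["l0", "l1", "l2"], lambda m, f: calls.append((m, f)))
    print(calls)  # [('l0', True), ('l1', True), ('l2', False)]
```

Keeping the execution-order assumption in a wrapper like this leaves the FSDP API itself as a building block, as described above.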


3 participants