[BUG] [Fix-Suggested] ZeRO Stage 3 Overwrites Module ID Attribute Causing Incorrect Expert Placement on GPUs #6772
Comments
@traincheck-team, thanks for providing a detailed report and suggesting reasonable solutions. We will work on this immediately.
@traincheck-team, please see #6847. I simplified your repro into a simple unit test. Please advise if this simplification is missing crucial aspects of this issue. Thanks!
@tjruwase, thanks for the fast response!
Thanks for the quick review.
Yes, for this problem, I think your PR is the best solution, as discussed. Thanks!
Fix #6772 (Co-authored-by: Logan Adams)
Description
We experienced wrong GPU placement when doing MoE with ZeRO Stage 3. We use `module.id` to control which expert is loaded onto which GPU for fine-grained control, and we found that `module.id` got corrupted after `deepspeed.initialize`.

Suspected Root Cause
DeepSpeed uses `.id` in ZeRO Stage 3 optimization to manage states, as seen in `runtime/zero/parameter_offload.py:L271`. This practice is brittle because:

- `id` is an overly generic attribute name that can easily collide with user-defined attributes.
- DeepSpeed does not check for an existing `.id` attribute before setting it, which allows accidental overwrites and causes hard-to-diagnose problems.

In the specific bug we've encountered (bug.py provided below), each expert module is identified by the `.id` attribute, but during initialization the `.id` is overwritten by the `_register_hooks_recursively` function in `deepspeed/runtime/zero/stage3.py`, which scrambles the expert-to-GPU placement.

To reproduce
1. Install deepspeed 0.15.4.
2. Run bug.py using `deepspeed --num_gpus=2 bug.py` (the `num_gpus` argument doesn't matter here; use 1 if you don't have multi-GPU nodes).
3. bug.py prints the `id`s of all experts twice, once before `deepspeed.initialize` and once after it. Observe that the first print gives `0, 1, 2, ..., 59` while the second gives `2, 4, 6, 8, ..., 120`.

The code in ZeRO Stage 3 responsible for overwriting the `.id` attribute sets `module.id` to a value based on a counter (`my_count`), which conflicts with user-defined `.id` attributes used for expert placement.

Bug Significance
This bug can significantly affect model behavior when expert modules are incorrectly placed across GPUs, leading to incorrect training outcomes or potential crashes. Ensuring that internal DeepSpeed modifications do not overwrite user-defined attributes is crucial for stability and expected functionality.
Even if user-side conflicts are not in your scope, DeepSpeed itself can accidentally modify these attributes as well. For example, you can reproduce the same problem by calling `deepspeed.initialize` multiple times. Thus, we argue for two fixes / engineering practices for this issue.
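The collision itself can be reproduced without DeepSpeed. The sketch below is a hypothetical, simplified stand-in: `Node` plays the role of an `nn.Module` tree and `assign_ids_recursively` mimics a routine that stamps a pre-order counter onto `.id` without checking whether the attribute already exists. Neither name is DeepSpeed's actual API.

```python
class Node:
    """Hypothetical stand-in for an nn.Module with child modules."""

    def __init__(self, children=()):
        self._children = list(children)

    def children(self):
        return iter(self._children)


def assign_ids_recursively(module, count):
    # Hypothetical: overwrite `.id` with the current counter value (no check
    # for an existing attribute), then recurse into children in pre-order.
    module.id = count[0]
    for child in module.children():
        count[0] += 1
        assign_ids_recursively(child, count)


# Hypothetical layout: one container holding 60 experts, each with one child.
experts = [Node(children=[Node()]) for _ in range(60)]
root = Node(children=[Node(children=experts)])

# User code: expert i is meant to be placed according to id == i.
for i, expert in enumerate(experts):
    expert.id = i
before = [e.id for e in experts]

assign_ids_recursively(root, count=[0])
after = [e.id for e in experts]

print(before[:4], after[:4])  # [0, 1, 2, 3] [2, 4, 6, 8]
```

With this particular hypothetical layout the overwritten ids come out as 2, 4, ..., 120, the same pattern reported above; the real model's layout may differ, but any unguarded write to a shared attribute name has the same effect.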
Expected Behavior / Suggested Fix
1. Avoid generic, overwritable attribute names. Instead of `.id`, use a more specific attribute name such as `_deepspeed_id` to avoid conflicts with user-defined attributes, and override the `__setattr__` method to only allow setting fields that have not been previously set, preventing unintentional overwrites of user-defined attributes.
2. Forbid duplicate calls to `deepspeed.initialize`. We observe many issues caused by accidental duplicate calls to `deepspeed.initialize`, so we suggest forbidding duplicate calls by recording the models / optimizers that have already been initialized, as mentioned in [BUG] [Fix-Suggested] KeyError in stage_1_and_2.py Due to Optimizer-Model Parameter Mismatch #6770.
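Both suggestions can be sketched in plain Python. This is a hypothetical illustration, not DeepSpeed code: `Node`, `set_once`, `assign_ds_ids`, and `guarded_initialize` are invented names, and instead of overriding `__setattr__` on every module, the sketch routes internal writes through a `set_once` helper. The duplicate-initialization guard records initialized models in a `weakref.WeakSet`, so the registry does not keep models alive.

```python
import weakref


class Node:
    """Hypothetical stand-in for an nn.Module with child modules."""

    def __init__(self, children=()):
        self._children = list(children)

    def children(self):
        return iter(self._children)


def set_once(obj, name, value):
    # Refuse to overwrite an attribute that already exists instead of
    # silently clobbering it (helper-function variant of the __setattr__ idea).
    if hasattr(obj, name):
        raise AttributeError(f"{type(obj).__name__}.{name} is already set")
    setattr(obj, name, value)


def assign_ds_ids(module, count):
    # Store the internal counter under a namespaced attribute so that a
    # user-defined `.id` is never touched.
    set_once(module, "_deepspeed_id", count[0])
    for child in module.children():
        count[0] += 1
        assign_ds_ids(child, count)


# Models that have already been initialized (held by weak reference).
_initialized = weakref.WeakSet()


def guarded_initialize(model):
    # Hypothetical wrapper: reject a second initialization of the same object.
    if model in _initialized:
        raise RuntimeError("this model has already been initialized")
    assign_ds_ids(model, count=[0])
    _initialized.add(model)
    return model


expert = Node(children=[Node()])
expert.id = 7  # user-defined placement id
model = Node(children=[expert])

guarded_initialize(model)
print(expert.id, expert._deepspeed_id)  # 7 1

try:
    guarded_initialize(model)
except RuntimeError as err:
    print("second call rejected:", err)
```

Checking the registry before assigning ids means a repeated call fails loudly with a clear message rather than tripping over (or, without `set_once`, silently rewriting) attributes from the first call.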
I will be more than happy to contribute to the two suggested fixes; let me know what you think!