Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Balance Loss to MoE Example for Enhanced Expert Load Distribution (Issue #1300) #1311

Open
wants to merge 1 commit into
base: master
Choose a base branch
from

Conversation

Mukku27
Copy link

@Mukku27 Mukku27 commented Oct 26, 2024

Add Balance Loss to MoE Example for Enhanced Expert Load Distribution (Issue #1300)


What changes were proposed in this pull request?

This pull request proposes the integration of a balance loss mechanism into the Mixture-of-Experts (MoE) example in the atorch codebase. Specifically, an auxiliary loss has been added to the TopNRouter class to facilitate balanced load distribution across experts, improving model performance and efficiency. Key modifications include:

  1. Router Updates:

    • Updated the TopNRouter class to compute and return an auxiliary loss based on router probabilities, helping distribute tokens more evenly.
    • Added _compute_auxiliary_loss() method to calculate this auxiliary loss, which is currently set up to use the mean of router probabilities as a placeholder. This can be customized based on specific balancing requirements.
  2. MoE Layer Enhancements:

    • Modified the _SparseMLP class to incorporate auxiliary loss from the router and propagate it back, enhancing the MoE layer’s load-balancing capabilities.
  3. Training Loop Modifications:

    • Integrated the auxiliary loss into the training loop by including it in the total loss computation. This ensures that the auxiliary loss is factored into backpropagation, aiding in balanced expert utilization.
  4. Auxiliary Loss Weight Configurability:

    • Added a command-line argument, --aux_loss_weight, to allow users to adjust the weight of the auxiliary loss as needed. This flexibility enables fine-tuning of the loss function based on model requirements.

Why are the changes needed?

These changes address issue #1300 by introducing an auxiliary balance loss mechanism to the MoE example, which aims to improve the distribution of workload across experts. In multi-expert architectures, load imbalances can lead to inefficiencies and underutilized resources, ultimately impacting model performance. The proposed auxiliary loss provides a straightforward way to mitigate these imbalances, enhancing both training efficiency and overall model effectiveness.

Does this PR introduce any user-facing change?

Yes, this PR introduces a new command-line argument, --aux_loss_weight, which allows users to adjust the weight of the auxiliary loss as needed. By default, it is set to 0.01 but can be configured according to specific model or training needs.

How was this patch tested?

The patch was tested through the following steps:

  • Functional Testing: Verified that the auxiliary loss is calculated and returned as expected in the TopNRouter class and is correctly incorporated into the MoE layer.
  • Integration Testing: Confirmed that the training loop correctly computes total loss with the auxiliary loss and backpropagates without errors.
  • Command-line Argument Testing: Ensured that the --aux_loss_weight argument correctly adjusts the auxiliary loss weight in different runs.

These changes are expected to contribute to improved expert load balancing, benefiting users who require scalable and efficient MoE models.

@Mukku27
Copy link
Author

Mukku27 commented Oct 28, 2024

@skydoorkai @adamantboy @hxdtest
Please review the changes, and if everything looks good, I would appreciate it if you could merge the PR.If any improvements required please spectify

Copy link
Collaborator

@merlintang merlintang left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Copy link

codecov bot commented Oct 29, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 80.52%. Comparing base (0a77136) to head (988d0f8).
Report is 3 commits behind head on master.

Additional details and impacted files
@@           Coverage Diff           @@
##           master    #1311   +/-   ##
=======================================
  Coverage   80.51%   80.52%           
=======================================
  Files         222      222           
  Lines       20698    20707    +9     
=======================================
+ Hits        16666    16674    +8     
- Misses       4032     4033    +1     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@@ -162,7 +182,8 @@ def forward(self, hidden_states):
if self.shared_experts is not None and not self.use_expert_parallelism:
hidden_states = hidden_states + self.shared_experts(identify)

return hidden_states
# Return the auxiliary loss along with the hidden states
return hidden_states, aux_loss
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This example is based on transformers llama model.
llama mlp definition only returns hidden_states, return tuple won't work.


return router_probs, router_logits, topk_experts_index, aux_loss

def _compute_auxiliary_loss(self, router_probs):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Need to implement a real aux loss that works in this example, not a placeholder.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants