Add Balance Loss to MoE Example for Enhanced Expert Load Distribution (Issue #1300) #1311

Open

wants to merge 1 commit into base: master
37 changes: 34 additions & 3 deletions atorch/examples/moe/train_moe_dummy_data.py
@@ -63,7 +63,27 @@ def _route_tokens(self, router_logits: torch.Tensor):
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
router_logits = self._compute_router_probabilities(hidden_states)
router_probs, topk_experts_index = self._route_tokens(router_logits)
return router_probs, router_logits, topk_experts_index

# Calculate auxiliary loss for balancing
aux_loss = self._compute_auxiliary_loss(router_probs)

return router_probs, router_logits, topk_experts_index, aux_loss

def _compute_auxiliary_loss(self, router_probs):
# Implement the auxiliary loss calculation
# This is a placeholder for the actual implementation
# You might want to use the entropy of the router_probs or any other metric
return torch.mean(router_probs)
Collaborator:
Need to implement a real aux loss that works in this example, not a placeholder.

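A minimal sketch of what a working balance loss could look like here, following the Switch Transformer formulation (fraction of tokens routed to each expert, times the mean router probability per expert). The helper name, its signature, and the assumption that router_probs is the full softmax over all experts are illustrative, not part of this PR:

```python
import torch
import torch.nn.functional as F


def load_balancing_loss(router_probs, topk_experts_index, num_experts):
    # router_probs: (num_tokens, num_experts), softmax of the router logits
    # topk_experts_index: (num_tokens, top_k), expert ids selected per token
    expert_mask = F.one_hot(topk_experts_index, num_experts).float()  # (tokens, top_k, E)
    # fraction of tokens dispatched to each expert
    tokens_per_expert = expert_mask.sum(dim=1).mean(dim=0)            # (E,)
    # mean routing probability assigned to each expert
    prob_per_expert = router_probs.mean(dim=0)                        # (E,)
    # minimized when both token load and probability mass are uniform across experts
    return num_experts * torch.sum(tokens_per_expert * prob_per_expert)
```

Scaled by a small weight (the PR's --aux_loss_weight), this term pushes the router toward an even token distribution across experts, which torch.mean(router_probs) alone does not do.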


class _MLP(nn.Module):
@@ -130,7 +150,7 @@ def __init__(self, config):
self.shared_experts = None

def forward(self, hidden_states):
router_probs, router_logits, top_expert_index = self.router(hidden_states)
router_probs, router_logits, top_expert_index, aux_loss = self.router(hidden_states)
identify = hidden_states

if self.shared_experts is not None and self.use_expert_parallelism:
@@ -162,7 +182,8 @@ def forward(self, hidden_states):
if self.shared_experts is not None and not self.use_expert_parallelism:
hidden_states = hidden_states + self.shared_experts(identify)

return hidden_states
# Return the auxiliary loss along with the hidden states
return hidden_states, aux_loss
Collaborator:
This example is based on the transformers llama model.
The llama MLP definition only returns hidden_states, so returning a tuple won't work.
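One way to keep the llama MLP contract intact is to leave forward() returning only hidden_states and stash the balance loss on the module instead; the latest_aux_loss attribute below is a hypothetical name, not part of this PR:

```python
def forward(self, hidden_states):
    router_probs, router_logits, top_expert_index, aux_loss = self.router(hidden_states)
    # keep the original single-tensor return; expose the loss as an attribute instead
    self.latest_aux_loss = aux_loss
    ...  # expert dispatch / combine, unchanged from the existing example
    return hidden_states
```

The training loop can then collect the stashed losses from every MoE layer (see the loop sketch further down).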



def patch_llama(args):
@@ -301,6 +322,7 @@ def parse_args():
parser.add_argument("--not_use_atorch_rmsnorm", default=False, action="store_true")
parser.add_argument("--use_meta_init", default=False, action="store_true")
parser.add_argument("--use_distributed_dataloader", default=False, action="store_true")
parser.add_argument("--aux_loss_weight", type=float, default=0.01, help="Weight for the auxiliary loss")
parser.add_argument("--shared_expert_overlapping", default=False, action="store_true")
parser.add_argument("--max_checkpoint_module_num", type=int, default=-1, required=False)
parser.add_argument("--record_timeline", default=False, action="store_true")
@@ -398,6 +420,14 @@ def train_model(
for batch in dataloader:
optim.zero_grad()
batch = prepare_input(batch, device)
outputs, aux_loss = model(**batch)
loss = loss_func(batch, outputs)

# Add auxiliary loss to the total loss
total_loss = loss + args.aux_loss_weight * aux_loss

total_loss.backward()
optim.step()
outputs = model(**batch)
loss = loss_func(batch, outputs)
loss.backward()
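As added, this hunk leaves the pre-existing forward/backward pass in place below the new one, so each batch would be processed twice. A sketch of the loop with a single pass, folding in the balance loss (the aggregation assumes the stashed-attribute approach sketched above):

```python
for batch in dataloader:
    optim.zero_grad()
    batch = prepare_input(batch, device)
    outputs = model(**batch)
    loss = loss_func(batch, outputs)
    # sum the balance losses stashed by each MoE layer (hypothetical attribute)
    aux_loss = sum(
        m.latest_aux_loss for m in model.modules() if hasattr(m, "latest_aux_loss")
    )
    total_loss = loss + args.aux_loss_weight * aux_loss
    total_loss.backward()
    optim.step()
```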
@@ -568,3 +598,4 @@ def set_global_variable_from_args(args):
from atorch import npu # noqa
patch_llama(args)
train(args)
