diff --git a/atorch/examples/moe/train_moe_dummy_data.py b/atorch/examples/moe/train_moe_dummy_data.py
index 805d75c38..db6837e77 100644
--- a/atorch/examples/moe/train_moe_dummy_data.py
+++ b/atorch/examples/moe/train_moe_dummy_data.py
@@ -63,7 +63,17 @@ def _route_tokens(self, router_logits: torch.Tensor):
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         router_logits = self._compute_router_probabilities(hidden_states)
         router_probs, topk_experts_index = self._route_tokens(router_logits)
-        return router_probs, router_logits, topk_experts_index
+
+        # Calculate the auxiliary loss for load balancing
+        aux_loss = self._compute_auxiliary_loss(router_probs)
+
+        return router_probs, router_logits, topk_experts_index, aux_loss
+
+    def _compute_auxiliary_loss(self, router_probs):
+        # Placeholder for the actual auxiliary loss implementation.
+        # A real implementation could use the entropy of router_probs or a
+        # load-balancing term computed from the expert assignments.
+        return torch.mean(router_probs)
 
 
 class _MLP(nn.Module):
@@ -130,7 +140,7 @@ def __init__(self, config):
         self.shared_experts = None
 
     def forward(self, hidden_states):
-        router_probs, router_logits, top_expert_index = self.router(hidden_states)
+        router_probs, router_logits, top_expert_index, aux_loss = self.router(hidden_states)
         identify = hidden_states
 
         if self.shared_experts is not None and self.use_expert_parallelism:
@@ -162,7 +172,8 @@ def forward(self, hidden_states):
 
         if self.shared_experts is not None and not self.use_expert_parallelism:
             hidden_states = hidden_states + self.shared_experts(identify)
-        return hidden_states
+        # Return the auxiliary loss along with the hidden states
+        return hidden_states, aux_loss
 
 
 def patch_llama(args):
@@ -301,6 +312,7 @@ def parse_args():
     parser.add_argument("--not_use_atorch_rmsnorm", default=False, action="store_true")
     parser.add_argument("--use_meta_init", default=False, action="store_true")
     parser.add_argument("--use_distributed_dataloader", default=False, action="store_true")
+    parser.add_argument("--aux_loss_weight", type=float, default=0.01, help="Weight for the auxiliary loss")
     parser.add_argument("--shared_expert_overlapping", default=False, action="store_true")
     parser.add_argument("--max_checkpoint_module_num", type=int, default=-1, required=False)
     parser.add_argument("--record_timeline", default=False, action="store_true")
@@ -398,6 +410,10 @@ def train_model(
         for batch in dataloader:
             optim.zero_grad()
             batch = prepare_input(batch, device)
-            outputs = model(**batch)
-            loss = loss_func(batch, outputs)
-            loss.backward()
+            outputs, aux_loss = model(**batch)
+            loss = loss_func(batch, outputs)
+
+            # Add the weighted auxiliary loss to the total loss
+            total_loss = loss + args.aux_loss_weight * aux_loss
+
+            total_loss.backward()
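
Note: the patch leaves _compute_auxiliary_loss as a placeholder that returns torch.mean(router_probs). For reference, a common choice in MoE training is a Switch-Transformer-style load-balancing loss, which pushes the fraction of tokens dispatched to each expert toward the mean router probability for that expert. The sketch below is illustrative only and not part of the patch: it assumes the full softmax over experts (router_probs of shape [num_tokens, num_experts]) and the top-k expert indices are available, and the names load_balancing_loss, num_experts, and topk_experts_index are placeholders chosen here. Wiring it in would also require passing topk_experts_index into _compute_auxiliary_loss, which the patch currently does not do.

import torch
import torch.nn.functional as F


def load_balancing_loss(router_probs: torch.Tensor, topk_experts_index: torch.Tensor, num_experts: int) -> torch.Tensor:
    # expert_mask: which experts each token was dispatched to, shape [tokens, top_k, num_experts].
    expert_mask = F.one_hot(topk_experts_index, num_classes=num_experts).float()
    # tokens_per_expert: fraction of routing assignments received by each expert, shape [num_experts].
    tokens_per_expert = expert_mask.sum(dim=1).mean(dim=0)
    # mean_router_prob: mean softmax probability given to each expert, shape [num_experts].
    mean_router_prob = router_probs.mean(dim=0)
    # Scaled dot product: roughly 1.0 for a perfectly balanced top-1 router, larger when skewed.
    return num_experts * torch.sum(tokens_per_expert * mean_router_prob)


# Toy check: 8 tokens, 4 experts, top-1 routing.
logits = torch.randn(8, 4)
probs = torch.softmax(logits, dim=-1)
topk_idx = probs.topk(1, dim=-1).indices
print(load_balancing_loss(probs, topk_idx, num_experts=4))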