Skip to content

Commit

Permalink
apply z-loss and aux-loss only during training
Browse files — browse the repository at this point in the history
  • Loading branch information
xffxff committed Nov 12, 2024
1 parent 8aea96f commit 9facee4
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions aria/model/moe_lm.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -255,7 +255,8 @@ def routing(
- top_indices: Indices of top-k experts for each token.
- tokens_per_expert: Number of tokens assigned to each expert.
"""
logits = self.apply_z_loss(logits)
if self.training:
logits = self.apply_z_loss(logits)

top_logits, top_indices = torch.topk(logits, k=self.config.moe_topk, dim=1)
scores = torch.softmax(top_logits, dim=-1, dtype=torch.float32).type_as(logits)
Expand All @@ -267,7 +268,8 @@ def routing(
max=self.config.moe_num_experts - 1,
)

scores = self.apply_aux_loss(logits, tokens_per_expert, scores)
if self.training:
scores = self.apply_aux_loss(logits, tokens_per_expert, scores)
return scores, top_indices, tokens_per_expert

def forward(
Expand Down

0 comments on commit 9facee4

Please sign in to comment.