diff --git a/aria/model/moe_lm.py b/aria/model/moe_lm.py index c88efe0..ad9d117 100644 --- a/aria/model/moe_lm.py +++ b/aria/model/moe_lm.py @@ -255,7 +255,8 @@ def routing( - top_indices: Indices of top-k experts for each token. - tokens_per_expert: Number of tokens assigned to each expert. """ - logits = self.apply_z_loss(logits) + if self.training: + logits = self.apply_z_loss(logits) top_logits, top_indices = torch.topk(logits, k=self.config.moe_topk, dim=1) scores = torch.softmax(top_logits, dim=-1, dtype=torch.float32).type_as(logits) @@ -267,7 +268,8 @@ def routing( max=self.config.moe_num_experts - 1, ) - scores = self.apply_aux_loss(logits, tokens_per_expert, scores) + if self.training: + scores = self.apply_aux_loss(logits, tokens_per_expert, scores) return scores, top_indices, tokens_per_expert def forward(