Delay importing deepspeed comm for perf (huggingface#810)
Co-authored-by: Jinyan Chen <[email protected]>
Co-authored-by: regisss <[email protected]>
3 people authored Mar 17, 2024
1 parent 6e36e18 commit c7a5498
Showing 1 changed file with 8 additions and 10 deletions.
optimum/habana/transformers/models/mixtral/modeling_mixtral.py (18 changes: 8 additions & 10 deletions)
@@ -30,6 +30,7 @@
 from torch import nn
 from torch.nn import CrossEntropyLoss
 from transformers.cache_utils import Cache, DynamicCache
+from transformers.integrations.deepspeed import is_deepspeed_available
 from transformers.modeling_attn_mask_utils import (
     _prepare_4d_causal_attention_mask,
     _prepare_4d_causal_attention_mask_for_sdpa,
@@ -61,12 +62,6 @@
     print("Not using HPU fused scaled dot-product attention kernel.")
     FusedSDPA = None
 
-try:
-    from deepspeed import comm as dist
-except ImportError:
-    print("Not using HPU DeepSpeed.")
-    dist = None
-
 logger = logging.get_logger(__name__)
 
 
@@ -294,10 +289,13 @@ def gaudi_mixtral_block_sparse_moe_forward(self, hidden_states: torch.Tensor) ->
     # router_logits: (batch * sequence_length, n_experts)
     router_logits = self.gate(hidden_states)
 
-    if dist and dist.is_initialized():
-        output_tensors = [router_logits.clone() for _ in range(dist.get_world_size())]
-        dist.all_gather(output_tensors, router_logits)
-        router_logits = torch.cat(output_tensors, dim=1)
+    if is_deepspeed_available():
+        from deepspeed import comm as dist
+
+        if dist.is_initialized():
+            output_tensors = [router_logits.clone() for _ in range(dist.get_world_size())]
+            dist.all_gather(output_tensors, router_logits)
+            router_logits = torch.cat(output_tensors, dim=1)
 
     routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
     routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
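Note: the change is an instance of the standard lazy-import pattern for optional dependencies. Instead of importing deepspeed.comm at module load time, which makes every import of modeling_mixtral.py pay the DeepSpeed import cost, the import now happens inside the MoE forward path and only when is_deepspeed_available() reports that DeepSpeed is installed. Below is a minimal standalone sketch of the same pattern; it is not the optimum-habana code itself, and the helper name gather_router_logits is made up for illustration.

import torch
from transformers.integrations.deepspeed import is_deepspeed_available


def gather_router_logits(router_logits: torch.Tensor) -> torch.Tensor:
    """Gather router logits across ranks when running under DeepSpeed.

    deepspeed is imported here, at call time, so merely importing this
    module stays cheap when DeepSpeed is absent or unused.
    """
    if is_deepspeed_available():
        from deepspeed import comm as dist

        # Only gather when a distributed run has actually been initialized.
        if dist.is_initialized():
            gathered = [router_logits.clone() for _ in range(dist.get_world_size())]
            dist.all_gather(gathered, router_logits)
            return torch.cat(gathered, dim=1)
    return router_logits

As in the commit, only the point at which deepspeed is imported moves; the gather logic itself is unchanged.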
