reduce cpu host overhead when using moe (#5578)
The operation `.to('cpu')` is not necessary for exp_counts, and it causes a device-to-host synchronization that damages performance.
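For context, a minimal sketch (not part of this commit; the shapes and use of torch.randint are illustrative) of the behavior being removed: copying a CUDA tensor to the CPU blocks the host thread until all queued GPU work completes, whereas keeping the result on the device only enqueues the reduction and lets the host run ahead.

import torch

if torch.cuda.is_available():
    mask1 = torch.randint(0, 2, (4096, 64), device="cuda")

    # Device-to-host copy: the host stalls here until the GPU drains its queue.
    exp_counts_blocking = torch.sum(mask1, dim=0).detach().to("cpu")

    # On-device result: the reduction is merely enqueued; the host continues.
    exp_counts_async = torch.sum(mask1, dim=0).detach()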

Co-authored-by: Olatunji Ruwase <[email protected]>
ranzhejiang and tjruwase authored Aug 21, 2024
1 parent 8b191d7 commit 7260890
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions deepspeed/moe/sharded_moe.py
@@ -208,7 +208,7 @@ def top1gating(logits: Tensor,
     mask1 = einsum("s,se->se", used_token, mask1)
 
     # gating decisions
-    exp_counts = torch.sum(mask1, dim=0).detach().to('cpu')
+    exp_counts = torch.sum(mask1, dim=0).detach().to(logits.device)
 
     # if we don't want to drop any tokens
     if not drop_tokens:
@@ -324,7 +324,7 @@ def top2gating(logits: Tensor,
     l_aux = torch.mean(me * ce) * num_experts * num_experts
 
     # gating decisions
-    exp_counts = torch.sum(mask1 + mask2, dim=0)
+    exp_counts = torch.sum(mask1 + mask2, dim=0).detach().to(logits.device)
 
     if drop_tokens:
         # Calculate configured capacity and remove locations outside capacity from mask
@@ -368,7 +368,7 @@ def top2gating(logits: Tensor,
     combine_weights = combine1_sec + combine2_sec
     dispatch_mask = combine_weights.bool()
 
-    return l_aux, combine_weights, dispatch_mask, exp_counts.detach().to('cpu')
+    return l_aux, combine_weights, dispatch_mask, exp_counts
 
 
 def topkgating(
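With this change, exp_counts stays on logits.device, so a caller that needs host-side values (e.g., for logging expert load) pays the device-to-host synchronization explicitly and only when needed. A hypothetical caller-side sketch (the log_expert_counts flag is illustrative, not a DeepSpeed API):

if log_expert_counts:                    # hypothetical flag, not a DeepSpeed API
    counts_host = exp_counts.to("cpu")   # the one deliberate device-to-host sync
    print(counts_host.tolist())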
