diff --git a/deepspeed/moe/layer.py b/deepspeed/moe/layer.py
index 46f7924ac038..dfa9fcf4f464 100644
--- a/deepspeed/moe/layer.py
+++ b/deepspeed/moe/layer.py
@@ -32,6 +32,7 @@ class MoE(nn.Module):
         use_rts (bool, optional): default=True, whether to use Random Token Selection.
         use_tutel (bool, optional): default=False, whether to use Tutel optimizations (if installed).
         enable_expert_tensor_parallelism (bool, optional): default=False, whether to use tensor parallelism for experts
+        top2_2nd_expert_sampling (bool, optional): default=True, whether to perform sampling for the 2nd expert
     """

     def __init__(self,
@@ -48,7 +49,8 @@ def __init__(self,
                  drop_tokens: bool = True,
                  use_rts: bool = True,
                  use_tutel: bool = False,
-                 enable_expert_tensor_parallelism: bool = False) -> None:
+                 enable_expert_tensor_parallelism: bool = False,
+                 top2_2nd_expert_sampling: bool = True) -> None:

         super(MoE, self).__init__()

@@ -69,7 +71,8 @@ def __init__(self,

         experts = Experts(expert, self.num_local_experts, self.expert_group_name)
         self.deepspeed_moe = MOELayer(TopKGate(hidden_size, num_experts, k, capacity_factor, eval_capacity_factor,
-                                               min_capacity, noisy_gate_policy, drop_tokens, use_rts),
+                                               min_capacity, noisy_gate_policy, drop_tokens, use_rts,
+                                               top2_2nd_expert_sampling),
                                       experts,
                                       self.expert_group_name,
                                       self.ep_size,
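For context, a minimal sketch of how the new flag is threaded through the public API. The `MoE` constructor arguments are the ones visible in the hunks above; the expert module and the sizes are hypothetical placeholders, not part of this PR.

import torch.nn as nn
from deepspeed.moe.layer import MoE

# Hypothetical expert network; any nn.Module with matching in/out width works.
expert = nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024))

moe = MoE(
    hidden_size=1024,
    expert=expert,
    num_experts=8,
    k=2,                             # top-2 gating, where the new flag applies
    drop_tokens=False,               # top2gating now honors drop_tokens=False (see below)
    top2_2nd_expert_sampling=False,  # new flag: pick the 2nd expert greedily, no Gumbel noise
)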
diff --git a/deepspeed/moe/sharded_moe.py b/deepspeed/moe/sharded_moe.py
index e6a5292d7e4f..d6c023ec11d3 100644
--- a/deepspeed/moe/sharded_moe.py
+++ b/deepspeed/moe/sharded_moe.py
@@ -210,6 +210,11 @@ def top1gating(logits: Tensor,
     if not drop_tokens:
         new_capacity = torch.max(exp_counts).to(logits.device)
         dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=dist.get_world_group())
+        if groups._get_expert_model_parallel_world_size() == 1:
+            # If the non-expert is tensor-parallel, we need to pad the capacity to 'tp'.
+            # This is because we are going to activate drop_tokens() to drop duplicate tokens.
+            tp = 1 if groups.mpu is None else groups.mpu.get_tensor_model_parallel_world_size()
+            new_capacity = torch.ceil(new_capacity / tp).mul(tp).to(new_capacity.dtype)
     capacity = new_capacity

     # Compute l_aux
@@ -275,23 +280,27 @@ def top1gating(logits: Tensor,
     return l_aux, combine_weights, dispatch_mask, exp_counts


-def top2gating(logits: Tensor, capacity_factor: float, min_capacity: int) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
+def top2gating(logits: Tensor,
+               capacity_factor: float,
+               min_capacity: int,
+               drop_tokens: bool = True,
+               top2_2nd_expert_sampling: bool = True) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
     """Implements Top2Gating on logits."""
     # everything is in fp32 in this function
     gates = F.softmax(logits, dim=1)

-    capacity = _capacity(gates, torch.tensor(capacity_factor * 2), torch.tensor(min_capacity))
-
     # Create a mask for 1st's expert per token
     indices1_s = torch.argmax(gates, dim=1)
     num_experts = int(gates.shape[1])
     mask1 = F.one_hot(indices1_s, num_classes=num_experts)

-    # Create a mask for 2nd's expert per token using Gumbel-max trick
-    # https://timvieira.github.io/blog/post/2014/07/31/gumbel-max-trick/
-    logits_w_noise = logits + gumbel_rsample(logits.shape, device=logits.device)
+    if top2_2nd_expert_sampling:
+        # Create a mask for 2nd's expert per token using Gumbel-max trick
+        # https://timvieira.github.io/blog/post/2014/07/31/gumbel-max-trick/
+        logits += gumbel_rsample(logits.shape, device=logits.device)
+
     # Replace top-expert with min value
-    logits_except1 = logits_w_noise.masked_fill(mask1.bool(), float("-inf"))
+    logits_except1 = logits.masked_fill(mask1.bool(), float("-inf"))
     indices2_s = torch.argmax(logits_except1, dim=1)
     mask2 = F.one_hot(indices2_s, num_classes=num_experts)

@@ -301,17 +310,29 @@ def top2gating(logits: Tensor, capacity_factor: float, min_capacity: int) -> Tup
     # Update 2nd's location by accounting for locations of 1st
     locations2 += torch.sum(mask1, dim=0, keepdim=True)

-    # gating decisions
-    exp_counts = torch.sum(mask1, dim=0).detach().to('cpu')
-
     # Compute l_aux
     me = torch.mean(gates, dim=0)
     ce = torch.mean(mask1.float(), dim=0)
     l_aux = torch.mean(me * ce) * num_experts * num_experts

-    # Remove locations outside capacity from mask
-    mask1 *= torch.lt(locations1, capacity)
-    mask2 *= torch.lt(locations2, capacity)
+    # gating decisions
+    exp_counts = torch.sum(mask1 + mask2, dim=0)
+
+    if drop_tokens:
+        # Calculate configured capacity and remove locations outside capacity from mask
+        capacity = _capacity(gates, torch.tensor(capacity_factor * 2), torch.tensor(min_capacity))
+        mask1 *= torch.lt(locations1, capacity)
+        mask2 *= torch.lt(locations2, capacity)
+    else:
+        # Do not drop tokens - set capacity according to current expert assignments
+        new_capacity = torch.max(exp_counts)
+        dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=dist.get_world_group())
+        if groups._get_expert_model_parallel_world_size() == 1:
+            # If the non-expert is tensor-parallel, we need to pad the capacity to 'tp'.
+            # This is because we are going to activate drop_tokens() to drop duplicate tokens.
+            tp = 1 if groups.mpu is None else groups.mpu.get_tensor_model_parallel_world_size()
+            new_capacity = torch.ceil(new_capacity / tp).mul(tp).to(new_capacity.dtype)
+        capacity = new_capacity

     # Store the capacity location for each token
     locations1_s = torch.sum(locations1 * mask1, dim=1)
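A standalone illustration of the capacity rounding that both gating paths now perform: when the non-expert part of the model is tensor-parallel, every rank must agree on the capacity after the all-reduce, so the maximum is padded up to the next multiple of the tensor-parallel world size. The numbers are hypothetical; only torch is assumed.

import torch

new_capacity = torch.tensor(13)  # e.g. the busiest expert received 13 tokens
tp = 4                           # hypothetical tensor-parallel world size

# Same arithmetic as the patch: round capacity up to a multiple of tp.
padded = torch.ceil(new_capacity / tp).mul(tp).to(new_capacity.dtype)
assert padded.item() == 16       # 13 -> 16, the next multiple of 4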
@@ -338,7 +359,7 @@ def top2gating(logits: Tensor, capacity_factor: float, min_capacity: int) -> Tup
     combine_weights = combine1_sec + combine2_sec
     dispatch_mask = combine_weights.bool()

-    return l_aux, combine_weights, dispatch_mask, exp_counts
+    return l_aux, combine_weights, dispatch_mask, exp_counts.detach().to('cpu')


 class TopKGate(Module):
@@ -368,13 +389,14 @@ def __init__(self,
                  min_capacity: int = 8,
                  noisy_gate_policy: Optional[str] = None,
                  drop_tokens: bool = True,
-                 use_rts: bool = True) -> None:
+                 use_rts: bool = True,
+                 top2_2nd_expert_sampling: bool = True) -> None:
         super().__init__()

         # Only top-1 and top-2 are supported at the moment.
         if k != 1 and k != 2:
             raise ValueError('Only top-1 and top-2 gatings are supported.')
-        self.wg = torch.nn.Linear(model_dim, num_experts, bias=False).float()
+        self.wg = torch.nn.Linear(model_dim, num_experts, bias=False)
         self.k = k
         self.capacity_factor = capacity_factor
         self.eval_capacity_factor = eval_capacity_factor
@@ -385,6 +407,7 @@ def __init__(self,
         self.gate_time = 0.0
         self.drop_tokens = drop_tokens
         self.use_rts = use_rts
+        self.top2_2nd_expert_sampling = top2_2nd_expert_sampling

     def forward(self,
                 input: torch.Tensor,
@@ -394,13 +417,11 @@ def forward(self,
         if self.wall_clock_breakdown:
             self.timers(TOPK_GATE_TIMER).start()

-        if self.wg.weight.dtype != torch.float32:
-            self.wg = self.wg.float()
         input_fp32 = input.float()
         # input jittering
         if self.noisy_gate_policy == 'Jitter' and self.training:
             input_fp32 = multiplicative_jitter(input_fp32, device=input.device)
-        logits = self.wg(input_fp32)
+        logits = torch.nn.functional.linear(input_fp32, weight=self.wg.weight.float(), bias=None)

         if self.k == 1:
             gate_output = top1gating(logits, self.capacity_factor if self.training else self.eval_capacity_factor,
@@ -409,7 +430,7 @@ def forward(self,
         else:
             gate_output = top2gating(logits, self.capacity_factor if self.training else self.eval_capacity_factor,
-                                     self.min_capacity)
+                                     self.min_capacity, self.drop_tokens, self.top2_2nd_expert_sampling)

         if self.wall_clock_breakdown:
             self.timers(TOPK_GATE_TIMER).stop()
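The last two hunks replace the old in-place `self.wg.float()` conversion with an on-the-fly fp32 matmul, so the stored gate parameter keeps the model's dtype (e.g. bf16) while the gating math still runs in fp32. A standalone sketch of that idea, with hypothetical sizes:

import torch

wg = torch.nn.Linear(16, 4, bias=False).to(torch.bfloat16)  # gate weight stays bf16
x = torch.randn(8, 16, dtype=torch.bfloat16)                # token hidden states

# Cast only for the computation; the stored parameter is untouched.
logits = torch.nn.functional.linear(x.float(), weight=wg.weight.float(), bias=None)

assert logits.dtype == torch.float32      # gating decisions computed in fp32
assert wg.weight.dtype == torch.bfloat16  # parameter dtype unchanged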