Commit f45b265

google-internal visibility change.
PiperOrigin-RevId: 581928522
V-MoE Authors authored and copybara-github committed Nov 13, 2023
1 parent efcf732 commit f45b265
Showing 1 changed file with 26 additions and 3 deletions.
29 changes: 26 additions & 3 deletions vmoe/projects/soft_moe/router.py
@@ -17,10 +17,11 @@
 Results using this algorithm presented in the paper:
 - "From Sparse to Soft Mixture of Experts" (https://arxiv.org/abs/2308.00951).
 """
-from typing import Dict, Optional, Tuple
+from typing import Dict, Mapping, Optional, Tuple
 
 from absl import logging
 import flax.linen as nn
+from flax.linen import partitioning as nn_partitioning
 import jax
 import jax.numpy as jnp
 from vmoe import moe
@@ -51,9 +52,13 @@ class SoftRouter(nn.Module):
   precision: jax.lax.Precision = jax.lax.Precision.DEFAULT
   partition_spec: Optional[jax.sharding.PartitionSpec] = None
   compute_similarity_metrics: bool = True
+  partitioning_rules: Optional[Mapping[str, Tuple]] = None  # pylint: disable=g-bare-generic
 
   @nn.compact
   def __call__(self, inputs: Array) -> Tuple[BaseDispatcher, Dict[str, Array]]:
+    if self.partitioning_rules:
+      inputs = nn_partitioning.with_sharding_constraint(
+          inputs, self.partitioning_rules['inputs'])
     # Normalize inputs to have unit norm.
     dtype = self.dtype or inputs.dtype
     inputs = normalize(inputs.astype(dtype), axis=-1)
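
Note on the new partitioning_rules field: the diff reads its values out by the
keys 'inputs', 'mu', 'scale' and 'logits' and hands them to
with_sharding_constraint / param_with_axes as logical axis annotations. A
minimal sketch of what a caller might pass; only the keys are confirmed by
this commit, while the axis names, mesh mapping, and constructor arguments
below are illustrative assumptions:

    import jax
    import jax.numpy as jnp
    from flax.linen import partitioning as nn_partitioning

    # Hypothetical logical-axis annotations; the names are made up here.
    rules = {
        'inputs': ('batch', 'length', 'embed'),           # (groups, tokens, dim)
        'mu': ('embed', 'expert', 'slot'),                # (dim, experts, slots)
        'scale': (),                                      # scalar parameter
        'logits': ('batch', 'length', 'expert', 'slot'),  # (g, m, n, p)
    }
    router = SoftRouter(num_experts=8, num_slots=2, partitioning_rules=rules)
    x = jnp.ones((1, 16, 32))  # (groups, tokens, dim)

    # Logical axis names only bind to devices once mapped to mesh axes:
    with nn_partitioning.axis_rules((('batch', 'data'), ('expert', 'model'))):
      variables = router.init(jax.random.PRNGKey(0), x)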
@@ -63,6 +68,10 @@ def __call__(self, inputs: Array) -> Tuple[BaseDispatcher, Dict[str, Array]]:
       num_slots = moe.compute_capacity(
           group_size, self.num_experts, self.capacity_factor,
           ceil_or_round='round', multiple_of=1)
+      logging.info(
+          'With num_tokens=%d, num_experts=%d, num_slots=%d, '
+          'capacity_factor=%f.', group_size, self.num_experts,
+          num_slots, self.capacity_factor)
     else:
       num_slots = self.num_slots
       actual_capacity_factor = self.num_experts * num_slots / group_size
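
The two log statements report the same relationship between tokens, experts,
and slots. Assuming moe.compute_capacity implements the usual
round(capacity_factor * group_size / num_experts) rule (its formula is not
shown in this diff), the arithmetic looks like:

    group_size, num_experts, capacity_factor = 1024, 128, 1.0
    num_slots = round(capacity_factor * group_size / num_experts)   # 8
    # The else branch instead derives the factor actually achieved:
    actual_capacity_factor = num_experts * num_slots / group_size   # 1.0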
@@ -71,11 +80,22 @@ def __call__(self, inputs: Array) -> Tuple[BaseDispatcher, Dict[str, Array]]:
           '%sWith num_tokens=%d, num_experts=%d and num_slots=%d, the actual '
           'capacity_factor is %f.', pre, group_size, self.num_experts,
           self.num_slots, actual_capacity_factor)
-    mu = self.param('mu', self.mu_init, (dim, self.num_experts, num_slots))
+    if self.partitioning_rules:
+      mu = nn_partitioning.param_with_axes(
+          'mu', self.mu_init, (dim, self.num_experts, num_slots),
+          axes=self.partitioning_rules['mu'])
+    else:
+      mu = self.param('mu', self.mu_init, (dim, self.num_experts, num_slots))
     mu = normalize(mu.astype(dtype), axis=0)
     self.sow('intermediates', 'mu_unit', mu)
     # Scale inputs/mu before computing the logits.
-    scale = self.param('scale', self.scale_init, ()).astype(dtype)
+    if self.partitioning_rules:
+      scale = nn_partitioning.param_with_axes(
+          'scale', self.scale_init, (),
+          axes=self.partitioning_rules['scale'])
+    else:
+      scale = self.param('scale', self.scale_init, ())
+    scale = scale.astype(dtype)
     if inputs.size < mu.size:
       inputs = inputs * scale
     else:
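
When partitioning_rules is set, mu and scale are created with param_with_axes
instead of self.param. In standard flax.linen.partitioning usage this also
records the logical axis names in a separate 'params_axes' collection, from
which partition specs can be recovered later. Continuing the sketch above
(this readback is standard Flax behaviour, not something this commit adds):

    variables = router.init(jax.random.PRNGKey(0), x)
    specs = nn_partitioning.get_axis_names(variables['params_axes'])
    # specs['mu'] is a PartitionSpec of logical names, e.g.
    # PartitionSpec('embed', 'expert', 'slot') under the rules sketched above.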
@@ -89,6 +109,9 @@ def __call__(self, inputs: Array) -> Tuple[BaseDispatcher, Dict[str, Array]]:
     # Compute router logits between pairs of items (m) and total slots (n * p),
     # independently on each group (g).
     logits = jnp.einsum('gmd,dnp->gmnp', inputs, mu, precision=self.precision)
+    if self.partitioning_rules:
+      logits = nn_partitioning.with_sharding_constraint(
+          logits, self.partitioning_rules['logits'])
     logits = self.add_noise(logits)
     # Each slot takes a convex combination of the inputs.
     dispatch_weights = jax.nn.softmax(logits, axis=1)
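
The einsum comment pins down the layout: logits has shape (g, m, n, p) =
(groups, tokens, experts, slots), so the softmax over axis=1 normalizes
across a group's tokens, and each slot becomes a convex combination of them.
A tiny shape check with made-up sizes:

    import jax
    import jax.numpy as jnp

    g, m, d, n, p = 1, 4, 8, 2, 3                   # made-up sizes
    x = jnp.ones((g, m, d))
    mu = jnp.ones((d, n, p))
    logits = jnp.einsum('gmd,dnp->gmnp', x, mu)     # shape (1, 4, 2, 3)
    dispatch = jax.nn.softmax(logits, axis=1)       # normalize over tokens
    assert jnp.allclose(dispatch.sum(axis=1), 1.0)  # convex combination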
