google-internal visibility change. #152

Open · wants to merge 1 commit into base: main
29 changes: 26 additions & 3 deletions vmoe/projects/soft_moe/router.py
@@ -17,10 +17,11 @@
 Results using this algorithm presented in the paper:
 - "From Sparse to Soft Mixture of Experts" (https://arxiv.org/abs/2308.00951).
 """
-from typing import Dict, Optional, Tuple
+from typing import Dict, Mapping, Optional, Tuple

 from absl import logging
 import flax.linen as nn
+from flax.linen import partitioning as nn_partitioning
 import jax
 import jax.numpy as jnp
 from vmoe import moe
@@ -51,9 +52,13 @@ class SoftRouter(nn.Module):
   precision: jax.lax.Precision = jax.lax.Precision.DEFAULT
   partition_spec: Optional[jax.sharding.PartitionSpec] = None
   compute_similarity_metrics: bool = True
+  partitioning_rules: Optional[Mapping[str, Tuple]] = None  # pylint: disable=g-bare-generic

   @nn.compact
   def __call__(self, inputs: Array) -> Tuple[BaseDispatcher, Dict[str, Array]]:
+    if self.partitioning_rules:
+      inputs = nn_partitioning.with_sharding_constraint(
+          inputs, self.partitioning_rules['inputs'])
     # Normalize inputs to have unit norm.
     dtype = self.dtype or inputs.dtype
     inputs = normalize(inputs.astype(dtype), axis=-1)
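Note (reviewer sketch, not part of the diff): the new partitioning_rules attribute is a mapping from the names this change looks up ('inputs', 'mu', 'scale', 'logits') to axis tuples. A minimal illustration of how such a mapping might be populated; the axis names below, and whether they are logical names resolved through nn_partitioning axis rules or concrete mesh axes, are assumptions rather than something stated in the PR:

partitioning_rules = {
    'inputs': ('batch', 'length', 'embed'),           # activations: (group, token, dim)
    'logits': ('batch', 'length', 'expert', 'slot'),  # router logits: (group, token, expert, slot)
    'mu': ('embed', 'expert', 'slot'),                # parameter of shape (dim, num_experts, num_slots)
    'scale': (),                                      # scalar parameter
}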
@@ -63,6 +68,10 @@ def __call__(self, inputs: Array) -> Tuple[BaseDispatcher, Dict[str, Array]]:
       num_slots = moe.compute_capacity(
           group_size, self.num_experts, self.capacity_factor,
           ceil_or_round='round', multiple_of=1)
+      logging.info(
+          'With num_tokens=%d, num_experts=%d, num_slots=%d, '
+          'capacity_factor=%f.', group_size, self.num_experts,
+          num_slots, self.capacity_factor)
     else:
       num_slots = self.num_slots
       actual_capacity_factor = self.num_experts * num_slots / group_size
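To make the logged quantities concrete, here is a small numeric sketch with made-up sizes. The assumption, consistent with the actual_capacity_factor formula in the else branch, is that compute_capacity with ceil_or_round='round' and multiple_of=1 derives the slots per expert as round(capacity_factor * group_size / num_experts):

group_size = 1024       # num_tokens per group (made up)
num_experts = 128
capacity_factor = 1.0
num_slots = round(capacity_factor * group_size / num_experts)   # 8 slots per expert
actual_capacity_factor = num_experts * num_slots / group_size   # 1.0, as reported in the else branch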
@@ -71,11 +80,22 @@ def __call__(self, inputs: Array) -> Tuple[BaseDispatcher, Dict[str, Array]]:
           '%sWith num_tokens=%d, num_experts=%d and num_slots=%d, the actual '
           'capacity_factor is %f.', pre, group_size, self.num_experts,
           self.num_slots, actual_capacity_factor)
-    mu = self.param('mu', self.mu_init, (dim, self.num_experts, num_slots))
+    if self.partitioning_rules:
+      mu = nn_partitioning.param_with_axes(
+          'mu', self.mu_init, (dim, self.num_experts, num_slots),
+          axes=self.partitioning_rules['mu'])
+    else:
+      mu = self.param('mu', self.mu_init, (dim, self.num_experts, num_slots))
     mu = normalize(mu.astype(dtype), axis=0)
     self.sow('intermediates', 'mu_unit', mu)
     # Scale inputs/mu before computing the logits.
-    scale = self.param('scale', self.scale_init, ()).astype(dtype)
+    if self.partitioning_rules:
+      scale = nn_partitioning.param_with_axes(
+          'scale', self.scale_init, (),
+          axes=self.partitioning_rules['scale'])
+    else:
+      scale = self.param('scale', self.scale_init, ())
+    scale = scale.astype(dtype)
     if inputs.size < mu.size:
       inputs = inputs * scale
     else:
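The switch from self.param to nn_partitioning.param_with_axes is what makes partitioning_rules useful for parameters: param_with_axes creates the parameter just like self.param but also records the logical axis names in a params_axes collection, which downstream sharding code can map onto a device mesh. A self-contained sketch (module name, initializer, and axis names are illustrative, not taken from the PR):

import flax.linen as nn
from flax.linen import partitioning as nn_partitioning
import jax
import jax.numpy as jnp

class TinyRouter(nn.Module):
  num_experts: int = 4
  num_slots: int = 2

  @nn.compact
  def __call__(self, x):
    dim = x.shape[-1]
    # Like self.param, but additionally records logical axis names for later sharding.
    mu = nn_partitioning.param_with_axes(
        'mu', nn.initializers.lecun_normal(),
        (dim, self.num_experts, self.num_slots),
        axes=('embed', 'expert', 'slot'))
    return jnp.einsum('...d,dnp->...np', x, mu)

variables = TinyRouter().init(jax.random.PRNGKey(0), jnp.ones((3, 8)))
print(variables['params'])       # the 'mu' parameter, as with self.param
print(variables['params_axes'])  # axis-name metadata recorded by param_with_axes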
@@ -89,6 +109,9 @@ def __call__(self, inputs: Array) -> Tuple[BaseDispatcher, Dict[str, Array]]:
     # Compute router logits between pairs of items (m) and total slots (n * p),
     # independently on each group (g).
     logits = jnp.einsum('gmd,dnp->gmnp', inputs, mu, precision=self.precision)
+    if self.partitioning_rules:
+      logits = nn_partitioning.with_sharding_constraint(
+          logits, self.partitioning_rules['logits'])
     logits = self.add_noise(logits)
     # Each slot takes a convex combination of the inputs.
     dispatch_weights = jax.nn.softmax(logits, axis=1)
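A small, self-contained shape check of the routing step above (sizes are made up): the einsum produces one logit per (group, token, expert, slot), and the softmax over axis 1 (the token axis m) is what turns each slot into a convex combination of the input tokens.

import jax
import jax.numpy as jnp

g, m, d, n, p = 1, 6, 4, 3, 2   # groups, tokens, dim, experts, slots per expert
inputs = jax.random.normal(jax.random.PRNGKey(0), (g, m, d))
mu = jax.random.normal(jax.random.PRNGKey(1), (d, n, p))
logits = jnp.einsum('gmd,dnp->gmnp', inputs, mu)   # shape (g, m, n, p)
dispatch_weights = jax.nn.softmax(logits, axis=1)  # normalize over tokens m
print(dispatch_weights.sum(axis=1))                # all ones: each slot's weights sum to 1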