From eddb6aed9673be711030a2c6ab9c1d0e8472715b Mon Sep 17 00:00:00 2001
From: V-MoE Authors
Date: Mon, 13 Nov 2023 05:35:19 -0800
Subject: [PATCH] google-internal visibility change.

PiperOrigin-RevId: 581928522
---
 vmoe/projects/soft_moe/router.py | 29 ++++++++++++++++++++++++++---
 1 file changed, 26 insertions(+), 3 deletions(-)

diff --git a/vmoe/projects/soft_moe/router.py b/vmoe/projects/soft_moe/router.py
index 1fbae84..32365ee 100644
--- a/vmoe/projects/soft_moe/router.py
+++ b/vmoe/projects/soft_moe/router.py
@@ -17,10 +17,11 @@
 Results using this algorithm presented in the paper:
 - "From Sparse to Soft Mixture of Experts" (https://arxiv.org/abs/2308.00951).
 """
-from typing import Dict, Optional, Tuple
+from typing import Dict, Mapping, Optional, Tuple
 
 from absl import logging
 import flax.linen as nn
+from flax.linen import partitioning as nn_partitioning
 import jax
 import jax.numpy as jnp
 from vmoe import moe
@@ -51,9 +52,13 @@ class SoftRouter(nn.Module):
   precision: jax.lax.Precision = jax.lax.Precision.DEFAULT
   partition_spec: Optional[jax.sharding.PartitionSpec] = None
   compute_similarity_metrics: bool = True
+  partitioning_rules: Optional[Mapping[str, Tuple]] = None  # pylint: disable=g-bare-generic
 
   @nn.compact
   def __call__(self, inputs: Array) -> Tuple[BaseDispatcher, Dict[str, Array]]:
+    if self.partitioning_rules:
+      inputs = nn_partitioning.with_sharding_constraint(
+          inputs, self.partitioning_rules['inputs'])
     # Normalize inputs to have unit norm.
     dtype = self.dtype or inputs.dtype
     inputs = normalize(inputs.astype(dtype), axis=-1)
@@ -63,6 +68,10 @@ def __call__(self, inputs: Array) -> Tuple[BaseDispatcher, Dict[str, Array]]:
       num_slots = moe.compute_capacity(
           group_size, self.num_experts, self.capacity_factor,
           ceil_or_round='round', multiple_of=1)
+      logging.info(
+          'With num_tokens=%d, num_experts=%d, num_slots=%d, '
+          'capacity_factor=%f.', group_size, self.num_experts,
+          num_slots, self.capacity_factor)
     else:
       num_slots = self.num_slots
       actual_capacity_factor = self.num_experts * num_slots / group_size
@@ -71,11 +80,22 @@ def __call__(self, inputs: Array) -> Tuple[BaseDispatcher, Dict[str, Array]]:
           '%sWith num_tokens=%d, num_experts=%d and num_slots=%d, the actual '
           'capacity_factor is %f.', pre, group_size, self.num_experts,
           self.num_slots, actual_capacity_factor)
-    mu = self.param('mu', self.mu_init, (dim, self.num_experts, num_slots))
+    if self.partitioning_rules:
+      mu = nn_partitioning.param_with_axes(
+          'mu', self.mu_init, (dim, self.num_experts, num_slots),
+          axes=self.partitioning_rules['mu'])
+    else:
+      mu = self.param('mu', self.mu_init, (dim, self.num_experts, num_slots))
     mu = normalize(mu.astype(dtype), axis=0)
     self.sow('intermediates', 'mu_unit', mu)
     # Scale inputs/mu before computing the logits.
-    scale = self.param('scale', self.scale_init, ()).astype(dtype)
+    if self.partitioning_rules:
+      scale = nn_partitioning.param_with_axes(
+          'scale', self.scale_init, (),
+          axes=self.partitioning_rules['scale'])
+    else:
+      scale = self.param('scale', self.scale_init, ())
+    scale = scale.astype(dtype)
     if inputs.size < mu.size:
       inputs = inputs * scale
     else:
@@ -89,6 +109,9 @@ def __call__(self, inputs: Array) -> Tuple[BaseDispatcher, Dict[str, Array]]:
     # Compute router logits between pairs of items (m) and total slots (n * p),
     # independently on each group (g).
     logits = jnp.einsum('gmd,dnp->gmnp', inputs, mu, precision=self.precision)
+    if self.partitioning_rules:
+      logits = nn_partitioning.with_sharding_constraint(
+          logits, self.partitioning_rules['logits'])
    logits = self.add_noise(logits)
     # Each slot takes a convex combination of the inputs.
     dispatch_weights = jax.nn.softmax(logits, axis=1)
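
Usage sketch (not part of the patch): one way to drive the new partitioning_rules
option added above. The dictionary keys 'inputs', 'mu', 'scale' and 'logits' come
from the diff; the logical axis names, the axis rules, the mesh axes 'data'/'model'
and the input shape are illustrative assumptions. Without an active
jax.sharding.Mesh, flax's sharding constraints are a no-op, so this also runs on a
single CPU device.

    import jax
    import jax.numpy as jnp
    from flax.linen import partitioning as nn_partitioning
    from vmoe.projects.soft_moe.router import SoftRouter

    router = SoftRouter(
        num_experts=8,
        num_slots=2,
        partitioning_rules={
            # Logical axis names for the activation sharding constraints...
            'inputs': ('batch', 'tokens', 'embed'),
            'logits': ('batch', 'tokens', 'expert', 'slot'),
            # ...and for the 'mu' and 'scale' parameters ('scale' is a scalar).
            'mu': ('embed', 'expert', 'slot'),
            'scale': (),
        })

    # Map logical axis names to mesh axes. Under a Mesh with 'data' and
    # 'model' axes these become real shardings (here, experts sharded over
    # the 'model' axis); without one they are skipped.
    rules = (('batch', 'data'), ('tokens', None), ('embed', None),
             ('expert', 'model'), ('slot', None))

    x = jnp.zeros((4, 196, 768))  # (groups, tokens, dim); shape is illustrative.
    with nn_partitioning.axis_rules(rules):
      variables = router.init(jax.random.PRNGKey(0), x)
      dispatcher, metrics = router.apply(variables, x)

Since param_with_axes sows the axis metadata into a separate 'params_axes'
collection, the variables returned by init above carry both the parameters and
their logical axis names, which downstream partitioning utilities can consume.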