From d4a800542bf2ff48c83b48b379768547c07d7782 Mon Sep 17 00:00:00 2001 From: Joan Puigcerver Date: Sun, 18 Dec 2022 08:54:19 -0800 Subject: [PATCH] Internal changes. PiperOrigin-RevId: 496228831 --- vmoe/partitioning.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/vmoe/partitioning.py b/vmoe/partitioning.py index 9506b32..1b9d5ba 100644 --- a/vmoe/partitioning.py +++ b/vmoe/partitioning.py @@ -84,6 +84,15 @@ UnparsedPartitionSpec = Union[str, Tuple[Union[str, Tuple[str, ...]], ...]] +def get_array_sharding_or_default(arr: jax.Array) -> sharding.Sharding: + if hasattr(arr, 'sharding'): + return arr.sharding + else: + op_sharding = jax.xla.xc.OpSharding() + op_sharding.type = jax.xla.xc.OpSharding.Type.REPLICATED + return sharding.OpShardingSharding(jax.devices(), op_sharding) + + def process_has_contiguous_device_slice(devices: np.ndarray, process_index: int) -> bool: """Checks if the devices of a process form a contiguous slice in the mesh."""