Skip to content

Commit

Permalink
Internal changes.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 496228831
  • Loading branch information
jpuigcerver authored and copybara-github committed Dec 18, 2022
1 parent 7c168db commit 142815c
Showing 1 changed file with 9 additions and 0 deletions.
9 changes: 9 additions & 0 deletions vmoe/partitioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,15 @@
UnparsedPartitionSpec = Union[str, Tuple[Union[str, Tuple[str, ...]], ...]]


def get_array_sharding_or_default(arr: jax.Array) -> sharding.Sharding:
if hasattr(arr, 'sharding'):
return arr.sharding
else:
op_sharding = jax.xla.xc.OpSharding()
op_sharding.type = jax.xla.xc.OpSharding.Type.REPLICATED
return sharding.OpShardingSharding(jax.devices(), op_sharding)


def process_has_contiguous_device_slice(devices: np.ndarray,
process_index: int) -> bool:
"""Checks if the devices of a process form a contiguous slice in the mesh."""
Expand Down

0 comments on commit 142815c

Please sign in to comment.