Add a high-level API for distribution for the JAX backend. (#741)
* Init commit for distributed training with JAX.
* WIP
* Add unit test for data parallel distribution.
* Update the TODO message in the test.
* Reduce the scope of the XLA flag in unit test.
* Update unit test for setup/teardown.
* Add JAX XLA backend reset for unit test.
* Multiple updates to the distribution:
  1. Rename distribute.py to distribution.py (same for tests).
  2. Merge the global state logic into the distribution class.
  3. Update all the unit tests.
* Updates.
* Updates.
* Formatting.
* Update test for debug setup/teardown.
* Further debug for unit test.
* Lift the config logic out for testing.
* More debug for the unit test cleanup.
* Fix the unit test with warning.
* Address review comments.
* Address review comments.
Showing 4 changed files with 321 additions and 1 deletion.
@@ -0,0 +1,121 @@
"""!!!DO NOT USE!!!

Distribution related class for the JAX backend.

This is just a prototype, and we might want to unify it with other backends
in the future.
"""

import contextlib

from absl import logging
import jax
import numpy as np

from keras_core.backend.common import global_state

DEFAULT_BATCH_DIM_NAME = "batch"
GLOBAL_ATTRIBUTE_NAME = "distribution"


def get_global_distribution():
    """Retrieve the current distribution from the global context."""
    return global_state.get_global_attribute(GLOBAL_ATTRIBUTE_NAME)


def set_global_distribution(distribution):
    """Set the distribution as the global distribution setting."""
    # TODO(qlzh727): Do type checking for input once we have a base class.
    distribution.as_global_distribution()


class DataParallelDistribution:
    def __init__(self, mesh=None, devices=None):
        """Create the data parallel distribution.

        Users can create this instance with either the `mesh` or the
        `devices` argument (but not both).

        The mesh is expected to be a 1D `jax.sharding.Mesh` instance. If the
        mesh has multiple axes, the first axis will be treated as the data
        parallel dimension (and a warning will be raised).

        When a list of `devices` is provided, it will be used to construct a
        1D mesh.

        When both `mesh` and `devices` are absent, `jax.devices()` is used to
        detect the available devices and build the mesh from them.
        """
        super().__init__()
        if mesh:
            self._initialize_with_mesh(mesh)
        elif devices:
            self._initialize_mesh_from_devices(devices)
        else:
            self._initialize_mesh_from_jax_devices()

        self._configure_sharding_spec()
        self._batch_dim_name = self.mesh.axis_names[0]

    @contextlib.contextmanager
    def scope(self):
        """Set this distribution as the global one within the context."""
        original_scope = global_state.get_global_attribute(
            GLOBAL_ATTRIBUTE_NAME
        )
        global_state.set_global_attribute(GLOBAL_ATTRIBUTE_NAME, self)
        try:
            yield
        finally:
            global_state.set_global_attribute(
                GLOBAL_ATTRIBUTE_NAME, original_scope
            )

    def as_global_distribution(self):
        """Set this distribution as the global one until it is replaced."""
        global_state.set_global_attribute(GLOBAL_ATTRIBUTE_NAME, self)

    def distribute_data(self, data):
        """Shard the data along the batch dimension of the mesh."""
        return jax.device_put(data, self._data_sharding)

    def distribute_variable(self, variable):
        """Replicate the variable across all devices of the mesh."""
        return jax.device_put(variable, self._variable_sharding)

    def _initialize_with_mesh(self, mesh):
        if not isinstance(mesh, jax.sharding.Mesh):
            raise ValueError(
                "Expected the mesh to be an instance of jax.sharding.Mesh. "
                f"Received: {type(mesh)}"
            )
        self._user_provide_devices = None
        self.mesh = mesh
        if self.mesh.devices.ndim != 1:
            logging.warning(
                "Expected the input mesh to be 1D, but received %dD. "
                "The first axis will be used for data parallel sharding.",
                self.mesh.devices.ndim,
            )

    def _initialize_mesh_from_devices(self, devices):
        self._user_provide_devices = devices
        self.mesh = jax.sharding.Mesh(np.array(devices), DEFAULT_BATCH_DIM_NAME)

    def _initialize_mesh_from_jax_devices(self):
        self._user_provide_devices = None
        self.mesh = jax.sharding.Mesh(
            np.array(jax.devices()), DEFAULT_BATCH_DIM_NAME
        )

    def _configure_sharding_spec(self):
        variable_shard_spec = [
            None
        ] * self.mesh.devices.ndim  # Fully replicated
        data_shard_spec = variable_shard_spec.copy()
        data_shard_spec[0] = self.mesh.axis_names[0]  # Shard on the first dim

        self._data_sharding = jax.sharding.NamedSharding(
            self.mesh, jax.sharding.PartitionSpec(*data_shard_spec)
        )
        self._variable_sharding = jax.sharding.NamedSharding(
            self.mesh, jax.sharding.PartitionSpec(*variable_shard_spec)
        )
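To make the intended workflow concrete, here is a minimal usage sketch based only on the API above, with the import path taken from the test file below. The array shapes are illustrative and assume the leading (batch) dimension divides evenly across the detected devices.

import numpy as np

from keras_core.backend.jax import distribution

# Build a 1D mesh over all detected devices; "batch" is the default axis name.
dist = distribution.DataParallelDistribution()

# Inputs are sharded along the batch (first) dimension across devices.
batch = np.arange(32, dtype="float32").reshape((8, 4))
sharded_batch = dist.distribute_data(batch)

# Variables are fully replicated on every device.
kernel = np.ones((4, 2), dtype="float32")
replicated_kernel = dist.distribute_variable(kernel)

# Within the scope, the distribution is visible via the global context.
with dist.scope():
    assert distribution.get_global_distribution() is dist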
@@ -0,0 +1,183 @@
"""Tests for JAX based distribution."""
import os

import jax
import numpy as np
import pytest

from keras_core import backend
from keras_core import layers
from keras_core import models
from keras_core import testing
from keras_core.backend.common import global_state
from keras_core.backend.jax import distribution

# Due to https://github.com/google/jax/issues/17188, we can't
# override the XLA flag after the JAX backend init. We have to
# run this at the top level to let JAX pick up the flag value.
xla_flags = os.getenv("XLA_FLAGS") or ""
# Don't override user-specified device count, or other XLA flags.
if "xla_force_host_platform_device_count" not in xla_flags:
    os.environ["XLA_FLAGS"] = (
        xla_flags + " --xla_force_host_platform_device_count=8"
    )


@pytest.mark.skipif(
    backend.backend() != "jax",
    reason="Only the JAX backend supports the distribution API for now.",
)
class DataParallelDistributionTest(testing.TestCase):
    def test_create_with_devices(self):
        devices = jax.devices()
        self.assertEqual(len(devices), 8)
        ds = distribution.DataParallelDistribution(devices=devices)

        mesh = ds.mesh
        self.assertEqual(len(mesh.devices), 8)
        self.assertEqual(mesh.axis_names, ("batch",))
        self.assertEqual(
            ds._data_sharding,
            jax.sharding.NamedSharding(
                mesh, jax.sharding.PartitionSpec("batch")
            ),
        )
        self.assertEqual(
            ds._variable_sharding,
            jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(None)),
        )

    def test_create_with_mesh(self):
        mesh = jax.sharding.Mesh(jax.devices(), "data")
        ds = distribution.DataParallelDistribution(mesh=mesh)
        self.assertEqual(ds.mesh, mesh)

        self.assertEqual(
            ds._data_sharding,
            jax.sharding.NamedSharding(
                mesh, jax.sharding.PartitionSpec("data")
            ),
        )
        self.assertEqual(
            ds._variable_sharding,
            jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(None)),
        )

    def test_create_with_available_devices(self):
        ds = distribution.DataParallelDistribution()

        mesh = ds.mesh
        self.assertEqual(len(mesh.devices), 8)
        self.assertEqual(mesh.axis_names, ("batch",))

        self.assertEqual(
            ds._data_sharding,
            jax.sharding.NamedSharding(
                mesh, jax.sharding.PartitionSpec("batch")
            ),
        )
        self.assertEqual(
            ds._variable_sharding,
            jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(None)),
        )

    def test_mesh_with_rank_2(self):
        mesh = jax.sharding.Mesh(
            np.array(jax.devices()).reshape(4, 2), ("data", "model")
        )
        ds = distribution.DataParallelDistribution(mesh=mesh)
        self.assertEqual(ds.mesh, mesh)

        self.assertEqual(
            ds._data_sharding,
            jax.sharding.NamedSharding(
                mesh, jax.sharding.PartitionSpec("data", None)
            ),
        )
        self.assertEqual(
            ds._variable_sharding,
            jax.sharding.NamedSharding(
                mesh, jax.sharding.PartitionSpec(None, None)
            ),
        )

    def test_distribute_data(self):
        ds = distribution.DataParallelDistribution()

        data = np.arange(16).reshape((8, 2))
        distributed_data = ds.distribute_data(data)
        self.assertEqual(distributed_data.sharding, ds._data_sharding)

    def test_distribute_variable(self):
        ds = distribution.DataParallelDistribution()

        weights = np.arange(16).reshape((8, 2))
        distributed_weights = ds.distribute_variable(weights)
        self.assertEqual(distributed_weights.sharding, ds._variable_sharding)

    def test_scope(self):
        self.assertIsNone(distribution.get_global_distribution())
        data_distribution = distribution.DataParallelDistribution()
        with data_distribution.scope():
            self.assertIs(
                distribution.get_global_distribution(), data_distribution
            )
            data_distribution_2 = distribution.DataParallelDistribution()
            with data_distribution_2.scope():
                self.assertIs(
                    distribution.get_global_distribution(), data_distribution_2
                )

            self.assertIs(
                distribution.get_global_distribution(), data_distribution
            )

        self.assertIsNone(distribution.get_global_distribution())

    def test_as_global_distribution(self):
        try:
            self.assertIsNone(distribution.get_global_distribution())

            data_distribution = distribution.DataParallelDistribution()
            data_distribution.as_global_distribution()
            self.assertIs(
                distribution.get_global_distribution(), data_distribution
            )
        finally:
            # Clean up the global state.
            global_state.set_global_attribute(
                distribution.GLOBAL_ATTRIBUTE_NAME, None
            )

    def test_set_global_distribution(self):
        try:
            self.assertIsNone(distribution.get_global_distribution())

            data_distribution = distribution.DataParallelDistribution()
            distribution.set_global_distribution(data_distribution)
            self.assertIs(
                distribution.get_global_distribution(), data_distribution
            )
        finally:
            # Clean up the global state.
            global_state.set_global_attribute(
                distribution.GLOBAL_ATTRIBUTE_NAME, None
            )

    def test_e2e_model(self):
        data_distribution = distribution.DataParallelDistribution()
        with data_distribution.scope():
            inputs = layers.Input(shape=[28, 28, 1])
            y = layers.Flatten()(inputs)
            y = layers.Dense(units=200, use_bias=False, activation="relu")(y)
            y = layers.Dropout(0.4)(y)
            y = layers.Dense(units=10, activation="softmax")(y)
            model = models.Model(inputs=inputs, outputs=y)

        # Make sure all the weights are properly sharded.
        for weight in model.weights:
            self.assertEqual(
                weight._value.sharding, data_distribution._variable_sharding
            )

        # TODO(qlzh727): Need to validate the data sharding.
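The XLA flag used at the top of the test file also works for ad-hoc local experimentation. Below is a short sketch of simulating eight CPU devices in a standalone script; the count of 8 simply mirrors the test, and the flag must be set before JAX initializes its backend, per the JAX issue referenced above.

import os

# Must run before JAX initializes its backend (see the linked JAX issue).
os.environ["XLA_FLAGS"] = (
    os.getenv("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=8"
)

import jax

print(jax.devices())  # Expect 8 CPU devices on a CPU-only host.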