From acc1f662d3771dec1e2ef58dfa0b72296b68859f Mon Sep 17 00:00:00 2001
From: Qianli Scott Zhu
Date: Sat, 19 Aug 2023 15:39:42 -0700
Subject: [PATCH] Add a high level API for distribution for JAX backend. (#741)

* Init commit for distributed training with JAX.
* WIP
* Add unit test for data parallel distribution.
* Update the TODO message in the test.
* Reduce the scope of the XLA flag in unit test.
* Update unit test for setup/teardown
* Add JAX xla backend reset for unit test
* Multiple updates to the distribution.
  1. Rename the distribute.py to distribution.py (same for tests).
  2. Merge the global state logic to the distribution class.
  3. Update all the unit tests.
* Updates.
* Updates
* Formatting
* Update test for debug setup/teardown
* Further debug for unit test
* lift the config logic out for testing
* More debug for the unit test cleanup
* Fix the unit test with warning.
* Address review comments.
* Address review comments.
---
 keras_core/backend/jax/core.py              |   8 +-
 keras_core/backend/jax/distribution.py      | 121 +++++++++++++
 keras_core/backend/jax/distribution_test.py | 183 ++++++++++++++++++++
 keras_core/backend/jax/trainer.py           |  10 ++
 4 files changed, 321 insertions(+), 1 deletion(-)
 create mode 100644 keras_core/backend/jax/distribution.py
 create mode 100644 keras_core/backend/jax/distribution_test.py

diff --git a/keras_core/backend/jax/core.py b/keras_core/backend/jax/core.py
index 2c50e48e16c..d4148dbaf0d 100644
--- a/keras_core/backend/jax/core.py
+++ b/keras_core/backend/jax/core.py
@@ -10,6 +10,7 @@
 from keras_core.backend.common import standardize_dtype
 from keras_core.backend.common.keras_tensor import KerasTensor
 from keras_core.backend.common.stateless_scope import StatelessScope
+from keras_core.backend.jax import distribution
 from keras_core.utils.nest import pack_sequence_as
 
 DYNAMIC_SHAPES_OK = True
@@ -17,7 +18,12 @@
 
 class Variable(KerasVariable):
     def _initialize(self, value):
-        self._value = jnp.array(value, dtype=self._dtype)
+        value = jnp.array(value, dtype=self._dtype)
+        if distribution.get_global_distribution() is not None:
+            value = distribution.get_global_distribution().distribute_variable(
+                value
+            )
+        self._value = value
 
     def _direct_assign(self, value):
         self._value = value
diff --git a/keras_core/backend/jax/distribution.py b/keras_core/backend/jax/distribution.py
new file mode 100644
index 00000000000..73aef883866
--- /dev/null
+++ b/keras_core/backend/jax/distribution.py
@@ -0,0 +1,121 @@
+"""!!!DO NOT USE!!!
+
+Distribution-related classes for the JAX backend.
+
+This is just a prototype, and we might want to unify it with the other
+backends in the future.
+"""
+
+import contextlib
+
+from absl import logging
+import jax
+import numpy as np
+
+from keras_core.backend.common import global_state
+
+DEFAULT_BATCH_DIM_NAME = "batch"
+GLOBAL_ATTRIBUTE_NAME = "distribution"
+
+
+def get_global_distribution():
+    """Retrieve the current distribution from the global context."""
+    return global_state.get_global_attribute(GLOBAL_ATTRIBUTE_NAME)
+
+
+def set_global_distribution(distribution):
+    """Set the given distribution as the global distribution."""
+    # TODO(qlzh727): Do type checking for input once we have a base class.
+    distribution.as_global_distribution()
+
+
+class DataParallelDistribution:
+    def __init__(self, mesh=None, devices=None):
+        """Create the data parallel distribution.
+
+        Users can create this instance with either the `mesh` or the
+        `devices` argument (but not both).
+
+        The mesh is expected to be a `jax.sharding.Mesh` instance and should
+        be 1D. If the mesh has multiple axes, the first axis will be treated
+        as the data parallel dimension (and a warning will be raised).
+
+        When a list of `devices` is provided, it will be used to construct a
+        1D mesh.
+
+        When both `mesh` and `devices` are absent, `jax.devices()` is used to
+        detect the available devices, and a 1D mesh is created from them.
+        """
+        super().__init__()
+        if mesh:
+            self._initialize_with_mesh(mesh)
+        elif devices:
+            self._initialize_mesh_from_devices(devices)
+        else:
+            self._initialize_mesh_from_jax_devices()
+
+        self._configure_sharding_spec()
+        self._batch_dim_name = self.mesh.axis_names[0]
+
+    @contextlib.contextmanager
+    def scope(self):
+        original_scope = global_state.get_global_attribute(
+            GLOBAL_ATTRIBUTE_NAME
+        )
+        global_state.set_global_attribute(GLOBAL_ATTRIBUTE_NAME, self)
+        try:
+            yield
+        finally:
+            global_state.set_global_attribute(
+                GLOBAL_ATTRIBUTE_NAME, original_scope
+            )
+
+    def as_global_distribution(self):
+        global_state.set_global_attribute(GLOBAL_ATTRIBUTE_NAME, self)
+
+    def distribute_data(self, data):
+        return jax.device_put(data, self._data_sharding)
+
+    def distribute_variable(self, variable):
+        return jax.device_put(variable, self._variable_sharding)
+
+    def _initialize_with_mesh(self, mesh):
+        if not isinstance(mesh, jax.sharding.Mesh):
+            raise ValueError(
+                "Expected the mesh to be an instance of jax.sharding.Mesh. "
+                f"Received: {type(mesh)}"
+            )
+        self._user_provide_devices = None
+        self.mesh = mesh
+        if self.mesh.devices.ndim != 1:
+            logging.warning(
+                "Expected the input mesh to be 1D, but received %dD. "
+                "The first axis will be used for data parallel sharding.",
+                self.mesh.devices.ndim,
+            )
+
+    def _initialize_mesh_from_devices(self, devices):
+        self._user_provide_devices = devices
+        self.mesh = jax.sharding.Mesh(np.array(devices), DEFAULT_BATCH_DIM_NAME)
+
+    def _initialize_mesh_from_jax_devices(self):
+        self._user_provide_devices = None
+        self.mesh = jax.sharding.Mesh(
+            np.array(jax.devices()), DEFAULT_BATCH_DIM_NAME
+        )
+
+    def _configure_sharding_spec(self):
+        variable_shard_spec = [
+            None
+        ] * self.mesh.devices.ndim  # Fully replicated
+        data_shard_spec = variable_shard_spec.copy()
+        data_shard_spec[0] = self.mesh.axis_names[0]  # Shard on the first dim
+
+        self._data_sharding = jax.sharding.NamedSharding(
+            self.mesh, jax.sharding.PartitionSpec(*data_shard_spec)
+        )
+        self._variable_sharding = jax.sharding.NamedSharding(
+            self.mesh, jax.sharding.PartitionSpec(*variable_shard_spec)
+        )
diff --git a/keras_core/backend/jax/distribution_test.py b/keras_core/backend/jax/distribution_test.py
new file mode 100644
index 00000000000..9b95d56517b
--- /dev/null
+++ b/keras_core/backend/jax/distribution_test.py
@@ -0,0 +1,183 @@
+"""Tests for the JAX-based distribution."""
+import os
+
+import jax
+import numpy as np
+import pytest
+
+from keras_core import backend
+from keras_core import layers
+from keras_core import models
+from keras_core import testing
+from keras_core.backend.common import global_state
+from keras_core.backend.jax import distribution
+
+# Due to https://github.com/google/jax/issues/17188, we can't
+# override the XLA flag after the JAX backend is initialized. We have
+# to run this at the top level to let JAX pick up the flag value.
+xla_flags = os.getenv("XLA_FLAGS") or ""
+# Don't override a user-specified device count or other XLA flags.
+if "xla_force_host_platform_device_count" not in xla_flags: + os.environ["XLA_FLAGS"] = ( + xla_flags + " --xla_force_host_platform_device_count=8" + ) + + +@pytest.mark.skipif( + backend.backend() != "jax", + reason="Only JAX backend support distribution API for now.", +) +class DataParallelDistributionTest(testing.TestCase): + def test_create_with_devices(self): + devices = jax.devices() + self.assertEqual(len(devices), 8) + ds = distribution.DataParallelDistribution(devices=devices) + + mesh = ds.mesh + self.assertEqual(len(mesh.devices), 8) + self.assertEqual(mesh.axis_names, ("batch",)) + self.assertEqual( + ds._data_sharding, + jax.sharding.NamedSharding( + mesh, jax.sharding.PartitionSpec("batch") + ), + ) + self.assertEqual( + ds._variable_sharding, + jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(None)), + ) + + def test_create_with_mesh(self): + mesh = jax.sharding.Mesh(jax.devices(), "data") + ds = distribution.DataParallelDistribution(mesh=mesh) + self.assertEqual(ds.mesh, mesh) + + self.assertEqual( + ds._data_sharding, + jax.sharding.NamedSharding( + mesh, jax.sharding.PartitionSpec("data") + ), + ) + self.assertEqual( + ds._variable_sharding, + jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(None)), + ) + + def test_create_with_available_devices(self): + ds = distribution.DataParallelDistribution() + + mesh = ds.mesh + self.assertEqual(len(mesh.devices), 8) + self.assertEqual(mesh.axis_names, ("batch",)) + + self.assertEqual( + ds._data_sharding, + jax.sharding.NamedSharding( + mesh, jax.sharding.PartitionSpec("batch") + ), + ) + self.assertEqual( + ds._variable_sharding, + jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(None)), + ) + + def test_mesh_with_rank_2(self): + mesh = jax.sharding.Mesh( + np.array(jax.devices()).reshape(4, 2), ("data", "model") + ) + ds = distribution.DataParallelDistribution(mesh=mesh) + self.assertEqual(ds.mesh, mesh) + + self.assertEqual( + ds._data_sharding, + jax.sharding.NamedSharding( + mesh, jax.sharding.PartitionSpec("data", None) + ), + ) + self.assertEqual( + ds._variable_sharding, + jax.sharding.NamedSharding( + mesh, jax.sharding.PartitionSpec(None, None) + ), + ) + + def test_distribute_data(self): + ds = distribution.DataParallelDistribution() + + data = np.arange(16).reshape((8, 2)) + distributed_data = ds.distribute_data(data) + self.assertEqual(distributed_data.sharding, ds._data_sharding) + + def test_distribute_variable(self): + ds = distribution.DataParallelDistribution() + + weights = np.arange(16).reshape((8, 2)) + distributed_weights = ds.distribute_variable(weights) + self.assertEqual(distributed_weights.sharding, ds._variable_sharding) + + def test_scope(self): + self.assertIsNone(distribution.get_global_distribution()) + data_distribution = distribution.DataParallelDistribution() + with data_distribution.scope(): + self.assertIs( + distribution.get_global_distribution(), data_distribution + ) + data_distribution_2 = distribution.DataParallelDistribution() + with data_distribution_2.scope(): + self.assertIs( + distribution.get_global_distribution(), data_distribution_2 + ) + + self.assertIs( + distribution.get_global_distribution(), data_distribution + ) + + self.assertIsNone(distribution.get_global_distribution()) + + def test_as_global_distribution(self): + try: + self.assertIsNone(distribution.get_global_distribution()) + + data_distribution = distribution.DataParallelDistribution() + data_distribution.as_global_distribution() + self.assertIs( + 
distribution.get_global_distribution(), data_distribution + ) + finally: + # Cleanup the global state + global_state.set_global_attribute( + distribution.GLOBAL_ATTRIBUTE_NAME, None + ) + + def test_set_global_distribution(self): + try: + self.assertIsNone(distribution.get_global_distribution()) + + data_distribution = distribution.DataParallelDistribution() + distribution.set_global_distribution(data_distribution) + self.assertIs( + distribution.get_global_distribution(), data_distribution + ) + finally: + # Cleanup the global state + global_state.set_global_attribute( + distribution.GLOBAL_ATTRIBUTE_NAME, None + ) + + def test_e2e_model(self): + data_distribution = distribution.DataParallelDistribution() + with data_distribution.scope(): + inputs = layers.Input(shape=[28, 28, 1]) + y = layers.Flatten()(inputs) + y = layers.Dense(units=200, use_bias=False, activation="relu")(y) + y = layers.Dropout(0.4)(y) + y = layers.Dense(units=10, activation="softmax")(y) + model = models.Model(inputs=inputs, outputs=y) + + # Make sure all the weights are properly sharded. + for weight in model.weights: + self.assertEqual( + weight._value.sharding, data_distribution._variable_sharding + ) + + # TODO(qlzh727): Need to validate the data sharding diff --git a/keras_core/backend/jax/trainer.py b/keras_core/backend/jax/trainer.py index 322f27667f7..07cad737098 100644 --- a/keras_core/backend/jax/trainer.py +++ b/keras_core/backend/jax/trainer.py @@ -6,6 +6,7 @@ from keras_core import callbacks as callbacks_module from keras_core import ops from keras_core import optimizers as optimizers_module +from keras_core.backend.jax import distribution from keras_core.trainers import trainer as base_trainer from keras_core.trainers.data_adapters import data_adapter_utils from keras_core.trainers.epoch_iterator import EpochIterator @@ -77,6 +78,7 @@ def train_step(self, state, data): metrics_variables, ) = state x, y, sample_weight = data_adapter_utils.unpack_x_y_sample_weight(data) + x, y, sample_weight = self._distribute_data((x, y, sample_weight)) grad_fn = jax.value_and_grad( self.compute_loss_and_updates, has_aux=True ) @@ -128,6 +130,7 @@ def test_step(self, state, data): metrics_variables, ) = state x, y, sample_weight = data_adapter_utils.unpack_x_y_sample_weight(data) + x, y, sample_weight = self._distribute_data((x, y, sample_weight)) loss, ( y_pred, non_trainable_variables, @@ -171,6 +174,7 @@ def predict_step(self, state, data): kwargs["training"] = False x, _, _ = data_adapter_utils.unpack_x_y_sample_weight(data) + x = self._distribute_data(x) outputs, non_trainable_variables = self.stateless_call( trainable_variables, non_trainable_variables, x, **kwargs ) @@ -747,3 +751,9 @@ def jax_state_sync(self): if metrics_variables: for ref_v, v in zip(self.metrics_variables, metrics_variables): ref_v.assign(v) + + def _distribute_data(self, data): + if distribution.get_global_distribution() is not None: + distribute = distribution.get_global_distribution() + data = distribute.distribute_data(data) + return data
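
Usage sketch (illustrative, not part of the patch): the snippet below shows how the
pieces added above are meant to fit together, based on the tests in
distribution_test.py. It assumes the JAX backend is active and that multiple devices
are visible (e.g. via the XLA_FLAGS setting used in the test); the model architecture
and array shapes are placeholders.

    import numpy as np

    from keras_core import layers
    from keras_core import models
    from keras_core.backend.jax import distribution

    # Build a data parallel distribution over all detected JAX devices
    # (a 1D "batch" mesh). Constructing the model inside the scope makes
    # Variable._initialize() place every weight with the replicated
    # variable sharding (see the core.py change above).
    data_distribution = distribution.DataParallelDistribution()
    with data_distribution.scope():
        inputs = layers.Input(shape=[28, 28, 1])
        y = layers.Flatten()(inputs)
        y = layers.Dense(units=10, activation="softmax")(y)
        model = models.Model(inputs=inputs, outputs=y)

    # Inputs can be sharded along the batch dimension explicitly; the JAX
    # trainer also does this per step via _distribute_data().
    batch = data_distribution.distribute_data(np.zeros((8, 28, 28, 1)))

    # Alternatively, install the distribution globally instead of using
    # a scoped context manager.
    distribution.set_global_distribution(data_distribution)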