Add a high-level API for distribution for the JAX backend. (#741)
* Init commit for distributed training with JAX.

* WIP

* Add unit test for data parallel distribution.

* Update the TODO message in the test.

* Reduce the scope of the XLA flag in the unit test.

* Update unit test for setup/teardown

* Add JAX xla backend reset for unit test

* Multiple updates to the distribution.

1. Rename distribute.py to distribution.py (same for tests).
2. Merge the global state logic into the distribution class.
3. Update all the unit tests.

* Updates.

* Updates

* Formatting

* Update test for debug setup/teardown

* Further debug for unit test

* Lift the config logic out for testing.

* More debug for the unit test cleanup

* Fix the unit test with warning.

* Address review comments.

* Address review comments.
qlzh727 authored and fchollet committed Aug 19, 2023
1 parent 8e99fcf commit acc1f66
Showing 4 changed files with 321 additions and 1 deletion.
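As a quick orientation, here is a hedged usage sketch of the API this commit adds (module path `keras_core.backend.jax.distribution`, as shown in the diff below; this is a prototype, not a documented entry point):

from keras_core.backend.jax import distribution

# Build a 1D "batch" mesh over all detected devices (the default path).
data_parallel = distribution.DataParallelDistribution()

# Either install it process-wide...
distribution.set_global_distribution(data_parallel)

# ...or restrict it to a scope; the previous global setting is restored on exit.
with data_parallel.scope():
    pass  # build and train the model here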
8 changes: 7 additions & 1 deletion keras_core/backend/jax/core.py
@@ -10,14 +10,20 @@
from keras_core.backend.common import standardize_dtype
from keras_core.backend.common.keras_tensor import KerasTensor
from keras_core.backend.common.stateless_scope import StatelessScope
from keras_core.backend.jax import distribution
from keras_core.utils.nest import pack_sequence_as

DYNAMIC_SHAPES_OK = True


class Variable(KerasVariable):
def _initialize(self, value):
self._value = jnp.array(value, dtype=self._dtype)
value = jnp.array(value, dtype=self._dtype)
if distribution.get_global_distribution() is not None:
value = distribution.get_global_distribution().distribute_variable(
value
)
self._value = value

def _direct_assign(self, value):
self._value = value
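With this change, `Variable._initialize` consults the global distribution, so weights created while a distribution is active are placed with its (fully replicated) variable sharding. A hedged sketch of the effect, mirroring `test_e2e_model` in the test file below:

from keras_core import layers
from keras_core import models
from keras_core.backend.jax import distribution

dist = distribution.DataParallelDistribution()
with dist.scope():
    inputs = layers.Input(shape=[4])
    outputs = layers.Dense(2)(inputs)
    model = models.Model(inputs=inputs, outputs=outputs)

# Each weight's backing jax.Array now carries the replicated sharding.
for weight in model.weights:
    assert weight._value.sharding == dist._variable_sharding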
121 changes: 121 additions & 0 deletions keras_core/backend/jax/distribution.py
@@ -0,0 +1,121 @@
"""!!!DO NOT USE!!!
Distribution related class for JAX backend.
This is just a prototype and we might want to unify it in future for other
backends.
"""

import contextlib

from absl import logging
import jax
import numpy as np

from keras_core.backend.common import global_state

DEFAULT_BATCH_DIM_NAME = "batch"
GLOBAL_ATTRIBUTE_NAME = "distribution"


def get_global_distribution():
"""Retrieve the current distribution from global context."""
return global_state.get_global_attribute(GLOBAL_ATTRIBUTE_NAME)


def set_global_distribution(distribution):
"""Set the distribution as the global distribution setting."""
# TODO(qlzh727): Do type checking for input once we have a base class.
distribution.as_global_distribution()


class DataParallelDistribution:
def __init__(self, mesh=None, devices=None):
"""Create the data parallel distribution.
User can choose to create this instance by either `Mesh` or `devices`
parameters (but not both).
The mesh is expected to be a `jax.sharding.Mesh` instance, and is
expected to be 1D only. In case that the mesh has multiple axises, then
the first axis will be treated as data parallel dimension (and a warning
will be raised).
When a list of `devices` are provided, they will be used to construct a
1D mesh.
When both `mesh` and `devices` are absent, then we will rely on
`jax.devices` to detect any available devices, and create mesh from
them.
"""
super().__init__()
if mesh:
self._initialize_with_mesh(mesh)
elif devices:
self._initialize_mesh_from_devices(devices)
else:
self._initialize_mesh_from_jax_devices()

self._configure_sharding_spec()
self._batch_dim_name = self.mesh.axis_names[0]

@contextlib.contextmanager
def scope(self):
original_scope = global_state.get_global_attribute(
GLOBAL_ATTRIBUTE_NAME
)
global_state.set_global_attribute(GLOBAL_ATTRIBUTE_NAME, self)
try:
yield
finally:
global_state.set_global_attribute(
GLOBAL_ATTRIBUTE_NAME, original_scope
)

def as_global_distribution(self):
global_state.set_global_attribute(GLOBAL_ATTRIBUTE_NAME, self)

def distribute_data(self, data):
return jax.device_put(data, self._data_sharding)

def distribute_variable(self, variable):
return jax.device_put(variable, self._variable_sharding)

def _initialize_with_mesh(self, mesh):
if not isinstance(mesh, jax.sharding.Mesh):
raise ValueError(
"Expect the mesh to be type of jax.sharding.Mesh, "
f"Received {type(mesh)}"
)
self._user_provide_devices = None
self.mesh = mesh
if self.mesh.devices.ndim != 1:
logging.warning(
"Expect the input mesh to be 1D, but received %dD. "
"The first axis will be used for data parallel sharding",
self.mesh.devices.ndim,
)

def _initialize_mesh_from_devices(self, devices):
self._user_provide_devices = devices
self.mesh = jax.sharding.Mesh(np.array(devices), DEFAULT_BATCH_DIM_NAME)

def _initialize_mesh_from_jax_devices(self):
self._user_provide_devices = None
self.mesh = jax.sharding.Mesh(
np.array(jax.devices()), DEFAULT_BATCH_DIM_NAME
)

def _configure_sharding_spec(self):
variable_shard_spec = [
None
] * self.mesh.devices.ndim # Fully replicated
data_shard_spec = variable_shard_spec.copy()
data_shard_spec[0] = self.mesh.axis_names[0] # Shard on the first dim

self._data_sharding = jax.sharding.NamedSharding(
self.mesh, jax.sharding.PartitionSpec(*data_shard_spec)
)
self._variable_sharding = jax.sharding.NamedSharding(
self.mesh, jax.sharding.PartitionSpec(*variable_shard_spec)
)
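To make the sharding rules above concrete, here is a hedged sketch of the three construction paths and the shardings they produce (assuming 8 host devices, as the test file below configures):

import jax
import numpy as np

from keras_core.backend.jax import distribution

# 1. Default: a 1D "batch" mesh over jax.devices().
ds = distribution.DataParallelDistribution()

# 2. From an explicit device list: also a 1D "batch" mesh.
ds_devices = distribution.DataParallelDistribution(devices=jax.devices())

# 3. From an existing mesh: the first axis is used for data parallelism,
#    and a warning is logged when the mesh has more than one axis.
mesh_2d = jax.sharding.Mesh(
    np.array(jax.devices()).reshape(4, 2), ("data", "model")
)
ds_2d = distribution.DataParallelDistribution(mesh=mesh_2d)

# Data is sharded along the first mesh axis; variables are fully replicated:
#   ds._data_sharding        -> NamedSharding(mesh, PartitionSpec("batch"))
#   ds._variable_sharding    -> NamedSharding(mesh, PartitionSpec(None))
#   ds_2d._data_sharding     -> NamedSharding(mesh_2d, PartitionSpec("data", None))
#   ds_2d._variable_sharding -> NamedSharding(mesh_2d, PartitionSpec(None, None))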
183 changes: 183 additions & 0 deletions keras_core/backend/jax/distribution_test.py
@@ -0,0 +1,183 @@
"""Tests for JAX based distribution."""
import os

import jax
import numpy as np
import pytest

from keras_core import backend
from keras_core import layers
from keras_core import models
from keras_core import testing
from keras_core.backend.common import global_state
from keras_core.backend.jax import distribution

# Due to https://github.com/google/jax/issues/17188, we can't
# override the XLA flag after the JAX backend is initialized. We have
# to run this at the top level so JAX picks up the flag value.
xla_flags = os.getenv("XLA_FLAGS") or ""
# Don't override user-specified device count, or other XLA flags.
if "xla_force_host_platform_device_count" not in xla_flags:
os.environ["XLA_FLAGS"] = (
xla_flags + " --xla_force_host_platform_device_count=8"
)


@pytest.mark.skipif(
backend.backend() != "jax",
reason="Only JAX backend support distribution API for now.",
)
class DataParallelDistributionTest(testing.TestCase):
def test_create_with_devices(self):
devices = jax.devices()
self.assertEqual(len(devices), 8)
ds = distribution.DataParallelDistribution(devices=devices)

mesh = ds.mesh
self.assertEqual(len(mesh.devices), 8)
self.assertEqual(mesh.axis_names, ("batch",))
self.assertEqual(
ds._data_sharding,
jax.sharding.NamedSharding(
mesh, jax.sharding.PartitionSpec("batch")
),
)
self.assertEqual(
ds._variable_sharding,
jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(None)),
)

def test_create_with_mesh(self):
mesh = jax.sharding.Mesh(jax.devices(), "data")
ds = distribution.DataParallelDistribution(mesh=mesh)
self.assertEqual(ds.mesh, mesh)

self.assertEqual(
ds._data_sharding,
jax.sharding.NamedSharding(
mesh, jax.sharding.PartitionSpec("data")
),
)
self.assertEqual(
ds._variable_sharding,
jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(None)),
)

def test_create_with_available_devices(self):
ds = distribution.DataParallelDistribution()

mesh = ds.mesh
self.assertEqual(len(mesh.devices), 8)
self.assertEqual(mesh.axis_names, ("batch",))

self.assertEqual(
ds._data_sharding,
jax.sharding.NamedSharding(
mesh, jax.sharding.PartitionSpec("batch")
),
)
self.assertEqual(
ds._variable_sharding,
jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(None)),
)

def test_mesh_with_rank_2(self):
mesh = jax.sharding.Mesh(
np.array(jax.devices()).reshape(4, 2), ("data", "model")
)
ds = distribution.DataParallelDistribution(mesh=mesh)
self.assertEqual(ds.mesh, mesh)

self.assertEqual(
ds._data_sharding,
jax.sharding.NamedSharding(
mesh, jax.sharding.PartitionSpec("data", None)
),
)
self.assertEqual(
ds._variable_sharding,
jax.sharding.NamedSharding(
mesh, jax.sharding.PartitionSpec(None, None)
),
)

def test_distribute_data(self):
ds = distribution.DataParallelDistribution()

data = np.arange(16).reshape((8, 2))
distributed_data = ds.distribute_data(data)
self.assertEqual(distributed_data.sharding, ds._data_sharding)

def test_distribute_variable(self):
ds = distribution.DataParallelDistribution()

weights = np.arange(16).reshape((8, 2))
distributed_weights = ds.distribute_variable(weights)
self.assertEqual(distributed_weights.sharding, ds._variable_sharding)

def test_scope(self):
self.assertIsNone(distribution.get_global_distribution())
data_distribution = distribution.DataParallelDistribution()
with data_distribution.scope():
self.assertIs(
distribution.get_global_distribution(), data_distribution
)
data_distribution_2 = distribution.DataParallelDistribution()
with data_distribution_2.scope():
self.assertIs(
distribution.get_global_distribution(), data_distribution_2
)

self.assertIs(
distribution.get_global_distribution(), data_distribution
)

self.assertIsNone(distribution.get_global_distribution())

def test_as_global_distribution(self):
try:
self.assertIsNone(distribution.get_global_distribution())

data_distribution = distribution.DataParallelDistribution()
data_distribution.as_global_distribution()
self.assertIs(
distribution.get_global_distribution(), data_distribution
)
finally:
# Cleanup the global state
global_state.set_global_attribute(
distribution.GLOBAL_ATTRIBUTE_NAME, None
)

def test_set_global_distribution(self):
try:
self.assertIsNone(distribution.get_global_distribution())

data_distribution = distribution.DataParallelDistribution()
distribution.set_global_distribution(data_distribution)
self.assertIs(
distribution.get_global_distribution(), data_distribution
)
finally:
# Cleanup the global state
global_state.set_global_attribute(
distribution.GLOBAL_ATTRIBUTE_NAME, None
)

def test_e2e_model(self):
data_distribution = distribution.DataParallelDistribution()
with data_distribution.scope():
inputs = layers.Input(shape=[28, 28, 1])
y = layers.Flatten()(inputs)
y = layers.Dense(units=200, use_bias=False, activation="relu")(y)
y = layers.Dropout(0.4)(y)
y = layers.Dense(units=10, activation="softmax")(y)
model = models.Model(inputs=inputs, outputs=y)

# Make sure all the weights are properly sharded.
for weight in model.weights:
self.assertEqual(
weight._value.sharding, data_distribution._variable_sharding
)

# TODO(qlzh727): Need to validate the data sharding
10 changes: 10 additions & 0 deletions keras_core/backend/jax/trainer.py
@@ -6,6 +6,7 @@
from keras_core import callbacks as callbacks_module
from keras_core import ops
from keras_core import optimizers as optimizers_module
from keras_core.backend.jax import distribution
from keras_core.trainers import trainer as base_trainer
from keras_core.trainers.data_adapters import data_adapter_utils
from keras_core.trainers.epoch_iterator import EpochIterator
@@ -77,6 +78,7 @@ def train_step(self, state, data):
metrics_variables,
) = state
x, y, sample_weight = data_adapter_utils.unpack_x_y_sample_weight(data)
x, y, sample_weight = self._distribute_data((x, y, sample_weight))
grad_fn = jax.value_and_grad(
self.compute_loss_and_updates, has_aux=True
)
@@ -128,6 +130,7 @@ def test_step(self, state, data):
metrics_variables,
) = state
x, y, sample_weight = data_adapter_utils.unpack_x_y_sample_weight(data)
x, y, sample_weight = self._distribute_data((x, y, sample_weight))
loss, (
y_pred,
non_trainable_variables,
@@ -171,6 +174,7 @@ def predict_step(self, state, data):
kwargs["training"] = False

x, _, _ = data_adapter_utils.unpack_x_y_sample_weight(data)
x = self._distribute_data(x)
outputs, non_trainable_variables = self.stateless_call(
trainable_variables, non_trainable_variables, x, **kwargs
)
@@ -747,3 +751,9 @@ def jax_state_sync(self):
if metrics_variables:
for ref_v, v in zip(self.metrics_variables, metrics_variables):
ref_v.assign(v)

def _distribute_data(self, data):
if distribution.get_global_distribution() is not None:
distribute = distribution.get_global_distribution()
data = distribute.distribute_data(data)
return data
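A hedged sketch of what the new `_distribute_data` hook does per batch (mirroring `test_distribute_data` above): when a global distribution is set, the incoming `(x, y, sample_weight)` structure is `device_put` with the batch-sharded layout before the jitted step runs.

import numpy as np

from keras_core.backend.jax import distribution

dist = distribution.DataParallelDistribution()
distribution.set_global_distribution(dist)

# The global batch dimension must be divisible by the number of devices.
batch = np.arange(16).reshape((8, 2))
sharded = dist.distribute_data(batch)
assert sharded.sharding == dist._data_sharding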
