Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for RaggedTensors for TensorFlow backend #385

Merged
merged 8 commits into from
Jun 26, 2023
48 changes: 47 additions & 1 deletion keras_core/backend/tensorflow/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,7 +487,7 @@ def get_data(iterator):
callbacks.on_predict_batch_end(step, {"outputs": batch_outputs})
callbacks.on_predict_end()
return tf.__internal__.nest.map_structure_up_to(
batch_outputs, np.concatenate, outputs
batch_outputs, potentially_ragged_concat, outputs
)

def train_on_batch(
Expand Down Expand Up @@ -826,3 +826,49 @@ def is_tpu_strat(k):
if is_tpu_strat(clz):
return True
return any(map(_is_tpu_strategy_class, clz.__bases__))


def potentially_ragged_concat(tensors):
"""Concats `Tensor`s along their first dimension.

Args:
tensors: List of `Tensor`s.

Returns:
Concatenation of the inputs along the first dimension -- of type
`Tensor` if all input shapes are compatible, or `RaggedTensor`
if not.
"""
if len(tensors) == 1:
return tensors[0]
if isinstance(tensors[0], tf.SparseTensor):
return tf.sparse.concat(axis=0, sp_inputs=tensors)
elif isinstance(tensors[0], tf.RaggedTensor):
return tf.concat(tensors, axis=0)
elif not tf.__internal__.tf2.enabled():
return tf.concat(tensors, axis=0)

non_batch_shapes = tf.stack([tf.shape(tensor)[1:] for tensor in tensors])
constant_dims = tf.math.reduce_all(
non_batch_shapes == non_batch_shapes[:1], axis=0
)
if tf.math.reduce_all(constant_dims).numpy().item():
# All non-batch dims are constant
if _is_scalar(tensors[0]):
return tf.stack(tensors, axis=0)
else:
return tf.concat(tensors, axis=0)

# First, identify constant inner dimensions by finding the
# rightmost dimension that is not constant
constant_inner_dimensions = (
constant_dims.numpy().tolist()[::-1].index(False)
)
# If there are constant inner dimensions, define a constant inner shape
if constant_inner_dimensions == 0:
constant_inner_shape = None
else:
constant_inner_shape = tensors[0].shape[-constant_inner_dimensions:]
return tf.ragged.constant(
[tensor.numpy() for tensor in tensors], inner_shape=constant_inner_shape
).merge_dims(0, 1)
11 changes: 8 additions & 3 deletions keras_core/trainers/data_adapters/array_data_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,10 @@ def convert_single_array(x):
x = np.expand_dims(x.to_numpy(dtype=dtype), axis=-1)
elif isinstance(x, pandas.DataFrame):
x = x.to_numpy(dtype=dtype)
if isinstance(x, (tf.Tensor, tf.Variable)):
x = x.numpy()
if isinstance(x, tf.RaggedTensor):
return tf.cast(x, dtype=dtype)
if not isinstance(x, np.ndarray):
# Using `__array__` should handle `tf.Tensor`, `jax.np.ndarray`,
# `torch.Tensor`, as well as any other tensor-like object that has
Expand All @@ -303,9 +307,10 @@ def convert_single_array(x):
x = np.array(x, dtype=dtype)
else:
raise ValueError(
"Expected a NumPy array, tf.Tensor, jax.np.ndarray, "
"torch.Tensor, Pandas Dataframe, or Pandas Series. "
f"Received invalid input: {x} (of type {type(x)})"
"Expected a NumPy array, tf.Tensor, tf.RaggedTensor, "
"jax.np.ndarray, torch.Tensor, Pandas Dataframe, or "
"Pandas Series. Received invalid input: "
f"{x} (of type {type(x)})"
)
if x.dtype == object:
return x
Expand Down
2 changes: 1 addition & 1 deletion keras_core/trainers/data_adapters/data_adapter_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# Leave jax, tf, and torch arrays off this list. Instead we will use
# `__array__` to detect these types. Doing so allows us to avoid importing a
# backend framework we are not currently using just to do type-checking.
ARRAY_TYPES = (np.ndarray,)
ARRAY_TYPES = (np.ndarray, tf.RaggedTensor)
if pandas:
ARRAY_TYPES = ARRAY_TYPES + (pandas.Series, pandas.DataFrame)

Expand Down
72 changes: 72 additions & 0 deletions keras_core/trainers/trainer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,3 +450,75 @@ def call(self, x):
x = np.ones((16, 2))
y = np.zeros((16, 1))
model.fit(x, y, batch_size=4)

def get_layer(self):
class ExampleLayer(keras_core.Layer):
def call(self, x):
return x * 2

return ExampleLayer

def get_model(self):
class ExampleModel(keras_core.Model):
def call(self, x):
return x * 2

return ExampleModel

def get_functional(self):
ExampleLayer = self.get_layer()

class ExampleFunctional(keras_core.Functional):
def __init__(self, input_shape=(None,)):
inputs = keras_core.Input(input_shape)
outputs = ExampleLayer()(inputs)
super().__init__(inputs=inputs, outputs=outputs)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you want to use at least 1 layer to check it actually works. Just make a custom layer

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, thanks!


return ExampleFunctional

@parameterized.named_parameters(
[
{
"testcase_name": "model",
"model_class": "get_model",
},
{
"testcase_name": "layer",
"model_class": "get_layer",
},
{
"testcase_name": "functional",
"model_class": "get_functional",
},
]
)
@pytest.mark.skipif(
keras_core.backend.backend() != "tensorflow",
reason="Only tensorflow supports raggeds",
)
def test_trainer_with_raggeds(self, model_class):
import tensorflow as tf

def loss_fn(y, y_pred, sample_weight=None):
return 0

model = getattr(self, model_class)()()
x = tf.ragged.constant([[1], [2, 3]])

# test forward pass
y = model(x)
self.assertEqual(type(y), tf.RaggedTensor)

# test training
if model_class in ["get_model", "get_functional"]:
model.compile(optimizer="adam", loss=loss_fn)
model.fit(x, x)
y = model.predict(x)
self.assertEqual(type(y), tf.RaggedTensor)

# test if everything works with the sequential model
model = keras_core.Sequential([model])
model.compile(optimizer="adam", loss=loss_fn)
model.fit(x, x)
y = model.predict(x)
self.assertEqual(type(y), tf.RaggedTensor)