Skip to content

Commit

Permalink
Add support for RaggedTensors for TensorFlow backend (#385)
Browse files Browse the repository at this point in the history
* Add support for RaggedTensors for TensorFlow backend

* Add tests for Layer and Functional base classes

* Add a test for the Sequential model

* Move skipif decorator after named_parameters decorator

* Fix the test for the FunctionalModel

* Use a custom layer in Functional test

* Use 4 space indent in docs

* Fix PEP 501: line too long
  • Loading branch information
tirthasheshpatel committed Jun 26, 2023
1 parent 8a6da7e commit 1bf82c3
Show file tree
Hide file tree
Showing 4 changed files with 128 additions and 5 deletions.
48 changes: 47 additions & 1 deletion keras_core/backend/tensorflow/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,7 +487,7 @@ def get_data(iterator):
callbacks.on_predict_batch_end(step, {"outputs": batch_outputs})
callbacks.on_predict_end()
return tf.__internal__.nest.map_structure_up_to(
batch_outputs, np.concatenate, outputs
batch_outputs, potentially_ragged_concat, outputs
)

def train_on_batch(
Expand Down Expand Up @@ -826,3 +826,49 @@ def is_tpu_strat(k):
if is_tpu_strat(clz):
return True
return any(map(_is_tpu_strategy_class, clz.__bases__))


def potentially_ragged_concat(tensors):
"""Concats `Tensor`s along their first dimension.
Args:
tensors: List of `Tensor`s.
Returns:
Concatenation of the inputs along the first dimension -- of type
`Tensor` if all input shapes are compatible, or `RaggedTensor`
if not.
"""
if len(tensors) == 1:
return tensors[0]
if isinstance(tensors[0], tf.SparseTensor):
return tf.sparse.concat(axis=0, sp_inputs=tensors)
elif isinstance(tensors[0], tf.RaggedTensor):
return tf.concat(tensors, axis=0)
elif not tf.__internal__.tf2.enabled():
return tf.concat(tensors, axis=0)

non_batch_shapes = tf.stack([tf.shape(tensor)[1:] for tensor in tensors])
constant_dims = tf.math.reduce_all(
non_batch_shapes == non_batch_shapes[:1], axis=0
)
if tf.math.reduce_all(constant_dims).numpy().item():
# All non-batch dims are constant
if _is_scalar(tensors[0]):
return tf.stack(tensors, axis=0)
else:
return tf.concat(tensors, axis=0)

# First, identify constant inner dimensions by finding the
# rightmost dimension that is not constant
constant_inner_dimensions = (
constant_dims.numpy().tolist()[::-1].index(False)
)
# If there are constant inner dimensions, define a constant inner shape
if constant_inner_dimensions == 0:
constant_inner_shape = None
else:
constant_inner_shape = tensors[0].shape[-constant_inner_dimensions:]
return tf.ragged.constant(
[tensor.numpy() for tensor in tensors], inner_shape=constant_inner_shape
).merge_dims(0, 1)
11 changes: 8 additions & 3 deletions keras_core/trainers/data_adapters/array_data_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,10 @@ def convert_single_array(x):
x = np.expand_dims(x.to_numpy(dtype=dtype), axis=-1)
elif isinstance(x, pandas.DataFrame):
x = x.to_numpy(dtype=dtype)
if isinstance(x, (tf.Tensor, tf.Variable)):
x = x.numpy()
if isinstance(x, tf.RaggedTensor):
return tf.cast(x, dtype=dtype)
if not isinstance(x, np.ndarray):
# Using `__array__` should handle `tf.Tensor`, `jax.np.ndarray`,
# `torch.Tensor`, as well as any other tensor-like object that has
Expand All @@ -303,9 +307,10 @@ def convert_single_array(x):
x = np.array(x, dtype=dtype)
else:
raise ValueError(
"Expected a NumPy array, tf.Tensor, jax.np.ndarray, "
"torch.Tensor, Pandas Dataframe, or Pandas Series. "
f"Received invalid input: {x} (of type {type(x)})"
"Expected a NumPy array, tf.Tensor, tf.RaggedTensor, "
"jax.np.ndarray, torch.Tensor, Pandas Dataframe, or "
"Pandas Series. Received invalid input: "
f"{x} (of type {type(x)})"
)
if x.dtype == object:
return x
Expand Down
2 changes: 1 addition & 1 deletion keras_core/trainers/data_adapters/data_adapter_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# Leave jax, tf, and torch arrays off this list. Instead we will use
# `__array__` to detect these types. Doing so allows us to avoid importing a
# backend framework we are not currently using just to do type-checking.
ARRAY_TYPES = (np.ndarray,)
ARRAY_TYPES = (np.ndarray, tf.RaggedTensor)
if pandas:
ARRAY_TYPES = ARRAY_TYPES + (pandas.Series, pandas.DataFrame)

Expand Down
72 changes: 72 additions & 0 deletions keras_core/trainers/trainer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,3 +450,75 @@ def call(self, x):
x = np.ones((16, 2))
y = np.zeros((16, 1))
model.fit(x, y, batch_size=4)

def get_layer(self):
class ExampleLayer(keras_core.Layer):
def call(self, x):
return x * 2

return ExampleLayer

def get_model(self):
class ExampleModel(keras_core.Model):
def call(self, x):
return x * 2

return ExampleModel

def get_functional(self):
ExampleLayer = self.get_layer()

class ExampleFunctional(keras_core.Functional):
def __init__(self, input_shape=(None,)):
inputs = keras_core.Input(input_shape)
outputs = ExampleLayer()(inputs)
super().__init__(inputs=inputs, outputs=outputs)

return ExampleFunctional

@parameterized.named_parameters(
[
{
"testcase_name": "model",
"model_class": "get_model",
},
{
"testcase_name": "layer",
"model_class": "get_layer",
},
{
"testcase_name": "functional",
"model_class": "get_functional",
},
]
)
@pytest.mark.skipif(
keras_core.backend.backend() != "tensorflow",
reason="Only tensorflow supports raggeds",
)
def test_trainer_with_raggeds(self, model_class):
import tensorflow as tf

def loss_fn(y, y_pred, sample_weight=None):
return 0

model = getattr(self, model_class)()()
x = tf.ragged.constant([[1], [2, 3]])

# test forward pass
y = model(x)
self.assertEqual(type(y), tf.RaggedTensor)

# test training
if model_class in ["get_model", "get_functional"]:
model.compile(optimizer="adam", loss=loss_fn)
model.fit(x, x)
y = model.predict(x)
self.assertEqual(type(y), tf.RaggedTensor)

# test if everything works with the sequential model
model = keras_core.Sequential([model])
model.compile(optimizer="adam", loss=loss_fn)
model.fit(x, x)
y = model.predict(x)
self.assertEqual(type(y), tf.RaggedTensor)

0 comments on commit 1bf82c3

Please sign in to comment.