From 632eb7c57d45561cd862646b2ca75364b3c79beb Mon Sep 17 00:00:00 2001 From: Tirth Patel Date: Wed, 21 Jun 2023 20:11:01 +0000 Subject: [PATCH 1/8] Add support for RaggedTensors for TensorFlow backend --- keras_core/backend/tensorflow/trainer.py | 49 ++++++++++++++++++- .../data_adapters/array_data_adapter.py | 11 +++-- .../data_adapters/data_adapter_utils.py | 2 +- keras_core/trainers/trainer_test.py | 21 ++++++++ 4 files changed, 78 insertions(+), 5 deletions(-) diff --git a/keras_core/backend/tensorflow/trainer.py b/keras_core/backend/tensorflow/trainer.py index 7187334b7..b473df62f 100644 --- a/keras_core/backend/tensorflow/trainer.py +++ b/keras_core/backend/tensorflow/trainer.py @@ -487,7 +487,7 @@ def get_data(iterator): callbacks.on_predict_batch_end(step, {"outputs": batch_outputs}) callbacks.on_predict_end() return tf.__internal__.nest.map_structure_up_to( - batch_outputs, np.concatenate, outputs + batch_outputs, potentially_ragged_concat, outputs ) def train_on_batch( @@ -826,3 +826,50 @@ def is_tpu_strat(k): if is_tpu_strat(clz): return True return any(map(_is_tpu_strategy_class, clz.__bases__)) + + +# This function is taken from keras.engine.training.potentially_ragged_concat +# Source: https://github.com/keras-team/keras/blob/5f9d052566f92821b323dabb111eb4679624ddfb/keras/engine/training.py#L4207-#L4249 # noqa: E501 +def potentially_ragged_concat(tensors): + """Concats `Tensor`s along their first dimension. + + Args: + tensors: List of `Tensor`s. + + Returns: + Concatenation of the inputs along the first dimension -- of type `Tensor` + if all input shapes are compatible, or `RaggedTensor` if not. 
+ """ + if len(tensors) == 1: + return tensors[0] + if isinstance(tensors[0], tf.SparseTensor): + return tf.sparse.concat(axis=0, sp_inputs=tensors) + elif isinstance(tensors[0], tf.RaggedTensor): + return tf.concat(tensors, axis=0) + elif not tf.__internal__.tf2.enabled(): + return tf.concat(tensors, axis=0) + + non_batch_shapes = tf.stack([tf.shape(tensor)[1:] for tensor in tensors]) + constant_dims = tf.math.reduce_all( + non_batch_shapes == non_batch_shapes[:1], axis=0 + ) + if tf.math.reduce_all(constant_dims).numpy().item(): + # All non-batch dims are constant + if _is_scalar(tensors[0]): + return tf.stack(tensors, axis=0) + else: + return tf.concat(tensors, axis=0) + + # First, identify constant inner dimensions by finding the + # rightmost dimension that is not constant + constant_inner_dimensions = ( + constant_dims.numpy().tolist()[::-1].index(False) + ) + # If there are constant inner dimensions, define a constant inner shape + if constant_inner_dimensions == 0: + constant_inner_shape = None + else: + constant_inner_shape = tensors[0].shape[-constant_inner_dimensions:] + return tf.ragged.constant( + [tensor.numpy() for tensor in tensors], inner_shape=constant_inner_shape + ).merge_dims(0, 1) diff --git a/keras_core/trainers/data_adapters/array_data_adapter.py b/keras_core/trainers/data_adapters/array_data_adapter.py index 679d5a1a5..f796e51a2 100644 --- a/keras_core/trainers/data_adapters/array_data_adapter.py +++ b/keras_core/trainers/data_adapters/array_data_adapter.py @@ -295,6 +295,10 @@ def convert_single_array(x): x = np.expand_dims(x.to_numpy(dtype=dtype), axis=-1) elif isinstance(x, pandas.DataFrame): x = x.to_numpy(dtype=dtype) + if isinstance(x, (tf.Tensor, tf.Variable)): + x = x.numpy() + if isinstance(x, tf.RaggedTensor): + return tf.cast(x, dtype=dtype) if not isinstance(x, np.ndarray): # Using `__array__` should handle `tf.Tensor`, `jax.np.ndarray`, # `torch.Tensor`, as well as any other tensor-like object that has @@ -303,9 +307,10 @@ def 
convert_single_array(x): x = np.array(x, dtype=dtype) else: raise ValueError( - "Expected a NumPy array, tf.Tensor, jax.np.ndarray, " - "torch.Tensor, Pandas Dataframe, or Pandas Series. " - f"Received invalid input: {x} (of type {type(x)})" + "Expected a NumPy array, tf.Tensor, tf.RaggedTensor, " + "jax.np.ndarray, torch.Tensor, Pandas Dataframe, or " + "Pandas Series. Received invalid input: " + f"{x} (of type {type(x)})" ) if x.dtype == object: return x diff --git a/keras_core/trainers/data_adapters/data_adapter_utils.py b/keras_core/trainers/data_adapters/data_adapter_utils.py index 00d4c7800..32d7c48cb 100644 --- a/keras_core/trainers/data_adapters/data_adapter_utils.py +++ b/keras_core/trainers/data_adapters/data_adapter_utils.py @@ -14,7 +14,7 @@ # Leave jax, tf, and torch arrays off this list. Instead we will use # `__array__` to detect these types. Doing so allows us to avoid importing a # backend framework we are not currently using just to do type-checking. -ARRAY_TYPES = (np.ndarray,) +ARRAY_TYPES = (np.ndarray, tf.RaggedTensor) if pandas: ARRAY_TYPES = ARRAY_TYPES + (pandas.Series, pandas.DataFrame) diff --git a/keras_core/trainers/trainer_test.py b/keras_core/trainers/trainer_test.py index 46c6143b0..5fcbbb0b4 100644 --- a/keras_core/trainers/trainer_test.py +++ b/keras_core/trainers/trainer_test.py @@ -450,3 +450,24 @@ def call(self, x): x = np.ones((16, 2)) y = np.zeros((16, 1)) model.fit(x, y, batch_size=4) + + @pytest.mark.skipif( + keras_core.backend.backend() != "tensorflow", + reason="Only tensorflow supports raggeds", + ) + def test_trainer_with_raggeds(self): + import tensorflow as tf + + class ExampleModel(keras_core.Model): + def call(self, x): + return 2 * x + + def compute_loss(self, x, y, y_pred, sample_weight=None): + return 0 + + model = ExampleModel() + x = tf.ragged.constant([[1], [2, 3]]) + model.compile(optimizer="adam") + model.fit(x, x) + y = model.predict(x) + self.assertEqual(type(y), tf.RaggedTensor) From 
f47f52476a4a9e7c9ddc5400429ecdb0cdc91f9a Mon Sep 17 00:00:00 2001 From: Tirth Patel Date: Fri, 23 Jun 2023 22:49:03 +0000 Subject: [PATCH 2/8] Add tests for Layer and Functional base classes --- keras_core/backend/tensorflow/trainer.py | 2 -- keras_core/trainers/trainer_test.py | 41 +++++++++++++++++++++--- 2 files changed, 36 insertions(+), 7 deletions(-) diff --git a/keras_core/backend/tensorflow/trainer.py b/keras_core/backend/tensorflow/trainer.py index b473df62f..0e22c6ebb 100644 --- a/keras_core/backend/tensorflow/trainer.py +++ b/keras_core/backend/tensorflow/trainer.py @@ -828,8 +828,6 @@ def is_tpu_strat(k): return any(map(_is_tpu_strategy_class, clz.__bases__)) -# This function is taken from keras.engine.training.potentially_ragged_concat -# Source: https://github.com/keras-team/keras/blob/5f9d052566f92821b323dabb111eb4679624ddfb/keras/engine/training.py#L4207-#L4249 # noqa: E501 def potentially_ragged_concat(tensors): """Concats `Tensor`s along their first dimension. diff --git a/keras_core/trainers/trainer_test.py b/keras_core/trainers/trainer_test.py index 5fcbbb0b4..cf3741195 100644 --- a/keras_core/trainers/trainer_test.py +++ b/keras_core/trainers/trainer_test.py @@ -455,10 +455,34 @@ def call(self, x): keras_core.backend.backend() != "tensorflow", reason="Only tensorflow supports raggeds", ) - def test_trainer_with_raggeds(self): + @parameterized.named_parameters( + [ + { + "testcase_name": "base_class_model", + "base_class": keras_core.Model, + }, + { + "testcase_name": "base_class_layer", + "base_class": keras_core.Layer, + }, + { + "testcase_name": "base_class_functional", + "base_class": keras_core.Functional, + }, + ] + ) + def test_trainer_with_raggeds(self, base_class): import tensorflow as tf - class ExampleModel(keras_core.Model): + class ExampleModel(base_class): + def __init__(self, input_shape=(None,)): + if base_class is keras_core.Functional: + inputs = keras_core.Input(input_shape) + outputs = inputs * 2 + 
super().__init__(inputs=inputs, outputs=outputs) + else: + super().__init__() + def call(self, x): return 2 * x @@ -467,7 +491,14 @@ def compute_loss(self, x, y, y_pred, sample_weight=None): model = ExampleModel() x = tf.ragged.constant([[1], [2, 3]]) - model.compile(optimizer="adam") - model.fit(x, x) - y = model.predict(x) + + # test forward pass + y = model(x) self.assertEqual(type(y), tf.RaggedTensor) + + # test training + if base_class in [keras_core.Model, keras_core.Functional]: + model.compile(optimizer="adam") + model.fit(x, x) + y = model.predict(x) + self.assertEqual(type(y), tf.RaggedTensor) From 3a9f0698d631f93123491740d728a631a84f1936 Mon Sep 17 00:00:00 2001 From: Tirth Patel Date: Fri, 23 Jun 2023 22:57:34 +0000 Subject: [PATCH 3/8] Add a test for the Sequential model --- keras_core/trainers/trainer_test.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/keras_core/trainers/trainer_test.py b/keras_core/trainers/trainer_test.py index cf3741195..525118892 100644 --- a/keras_core/trainers/trainer_test.py +++ b/keras_core/trainers/trainer_test.py @@ -486,8 +486,8 @@ def __init__(self, input_shape=(None,)): def call(self, x): return 2 * x - def compute_loss(self, x, y, y_pred, sample_weight=None): - return 0 + def loss_fn(y, y_pred, sample_weight=None): + return 0 model = ExampleModel() x = tf.ragged.constant([[1], [2, 3]]) @@ -498,7 +498,14 @@ def compute_loss(self, x, y, y_pred, sample_weight=None): # test training if base_class in [keras_core.Model, keras_core.Functional]: - model.compile(optimizer="adam") + model.compile(optimizer="adam", loss=loss_fn) model.fit(x, x) y = model.predict(x) self.assertEqual(type(y), tf.RaggedTensor) + + # test if everything works with the sequential model + model = keras_core.Sequential([model]) + model.compile(optimizer="adam", loss=loss_fn) + model.fit(x, x) + y = model.predict(x) + self.assertEqual(type(y), tf.RaggedTensor) From 3ff71d6225104452c4f3fae59db103b264751b78 Mon Sep 17 
00:00:00 2001 From: Tirth Patel Date: Fri, 23 Jun 2023 23:00:25 +0000 Subject: [PATCH 4/8] Move skipif decorator after named_parameters decorator --- keras_core/trainers/trainer_test.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/keras_core/trainers/trainer_test.py b/keras_core/trainers/trainer_test.py index 525118892..8ac7b7cfa 100644 --- a/keras_core/trainers/trainer_test.py +++ b/keras_core/trainers/trainer_test.py @@ -451,10 +451,6 @@ def call(self, x): y = np.zeros((16, 1)) model.fit(x, y, batch_size=4) - @pytest.mark.skipif( - keras_core.backend.backend() != "tensorflow", - reason="Only tensorflow supports raggeds", - ) @parameterized.named_parameters( [ { @@ -471,6 +467,10 @@ def call(self, x): }, ] ) + @pytest.mark.skipif( + keras_core.backend.backend() != "tensorflow", + reason="Only tensorflow supports raggeds", + ) def test_trainer_with_raggeds(self, base_class): import tensorflow as tf From 08909ed92f325c2386d7cd7e0fce7010e0940dc6 Mon Sep 17 00:00:00 2001 From: Tirth Patel Date: Mon, 26 Jun 2023 20:27:48 +0000 Subject: [PATCH 5/8] Fix the test for the FunctionalModel --- keras_core/trainers/trainer_test.py | 59 +++++++++++++++++++---------- 1 file changed, 38 insertions(+), 21 deletions(-) diff --git a/keras_core/trainers/trainer_test.py b/keras_core/trainers/trainer_test.py index 8ac7b7cfa..9f069ea86 100644 --- a/keras_core/trainers/trainer_test.py +++ b/keras_core/trainers/trainer_test.py @@ -451,19 +451,48 @@ def call(self, x): y = np.zeros((16, 1)) model.fit(x, y, batch_size=4) + def get_layer(self): + class ExampleLayer(keras_core.Layer): + def call(self, x): + return x * 2 + + return ExampleLayer + + def get_model(self): + class ExampleModel(keras_core.Model): + def call(self, x): + return x * 2 + + return ExampleModel + + def get_functional(self): + class ExampleFunctional(keras_core.Functional): + def __init__(self, input_shape=(None,)): + inputs = keras_core.Input(input_shape) + # The functional model uses the + # 
``tensorflow.experimental.numpy`` API which doesn't yet + # support RaggedTensors. So, most keras_core operations + # won't work when ragged tensors are passed to the Functional + # model. We just test that passing RaggedTensors works for + # now. + outputs = inputs + super().__init__(inputs=inputs, outputs=outputs) + + return ExampleFunctional + @parameterized.named_parameters( [ { - "testcase_name": "base_class_model", - "base_class": keras_core.Model, + "testcase_name": "model", + "model_class": "get_model", }, { - "testcase_name": "base_class_layer", - "base_class": keras_core.Layer, + "testcase_name": "layer", + "model_class": "get_layer", }, { - "testcase_name": "base_class_functional", - "base_class": keras_core.Functional, + "testcase_name": "functional", + "model_class": "get_functional", }, ] ) @@ -471,25 +500,13 @@ def call(self, x): keras_core.backend.backend() != "tensorflow", reason="Only tensorflow supports raggeds", ) - def test_trainer_with_raggeds(self, base_class): + def test_trainer_with_raggeds(self, model_class): import tensorflow as tf - class ExampleModel(base_class): - def __init__(self, input_shape=(None,)): - if base_class is keras_core.Functional: - inputs = keras_core.Input(input_shape) - outputs = inputs * 2 - super().__init__(inputs=inputs, outputs=outputs) - else: - super().__init__() - - def call(self, x): - return 2 * x - def loss_fn(y, y_pred, sample_weight=None): return 0 - model = ExampleModel() + model = getattr(self, model_class)()() x = tf.ragged.constant([[1], [2, 3]]) # test forward pass @@ -497,7 +514,7 @@ def loss_fn(y, y_pred, sample_weight=None): self.assertEqual(type(y), tf.RaggedTensor) # test training - if base_class in [keras_core.Model, keras_core.Functional]: + if model_class in ["get_model", "get_functional"]: model.compile(optimizer="adam", loss=loss_fn) model.fit(x, x) y = model.predict(x) From 801d94f6aad24eb2adfee962dfbe2085c80d937b Mon Sep 17 00:00:00 2001 From: Tirth Patel Date: Mon, 26 Jun 2023 21:55:24 
+0000 Subject: [PATCH 6/8] Use a custom layer in Functional test --- keras_core/trainers/trainer_test.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/keras_core/trainers/trainer_test.py b/keras_core/trainers/trainer_test.py index 9f069ea86..8a711fee9 100644 --- a/keras_core/trainers/trainer_test.py +++ b/keras_core/trainers/trainer_test.py @@ -466,16 +466,12 @@ def call(self, x): return ExampleModel def get_functional(self): + ExampleLayer = self.get_layer() + class ExampleFunctional(keras_core.Functional): def __init__(self, input_shape=(None,)): inputs = keras_core.Input(input_shape) - # The functional model uses the - # ``tensorflow.experimental.numpy`` API which doesn't yet - # support RaggedTensors. So, most keras_core operations - # won't work when ragged tensors are passed to the Functional - # model. We just test that passing RaggedTensors works for - # now. - outputs = inputs + outputs = ExampleLayer()(inputs) super().__init__(inputs=inputs, outputs=outputs) return ExampleFunctional From f36c32b6262ab27492380073cdd2d7f1c802f92f Mon Sep 17 00:00:00 2001 From: Tirth Patel Date: Mon, 26 Jun 2023 21:56:49 +0000 Subject: [PATCH 7/8] Use 4 space indent in docs --- keras_core/backend/tensorflow/trainer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/keras_core/backend/tensorflow/trainer.py b/keras_core/backend/tensorflow/trainer.py index 0e22c6ebb..1ba3c80a6 100644 --- a/keras_core/backend/tensorflow/trainer.py +++ b/keras_core/backend/tensorflow/trainer.py @@ -832,11 +832,11 @@ def potentially_ragged_concat(tensors): """Concats `Tensor`s along their first dimension. Args: - tensors: List of `Tensor`s. + tensors: List of `Tensor`s. Returns: - Concatenation of the inputs along the first dimension -- of type `Tensor` - if all input shapes are compatible, or `RaggedTensor` if not. 
+ Concatenation of the inputs along the first dimension -- of type `Tensor` + if all input shapes are compatible, or `RaggedTensor` if not. """ if len(tensors) == 1: return tensors[0] From 85c738eaa2c5a3f6172e98667b081dd40d34c196 Mon Sep 17 00:00:00 2001 From: Tirth Patel Date: Mon, 26 Jun 2023 22:33:26 +0000 Subject: [PATCH 8/8] Fix E501 (PEP 8): line too long --- keras_core/backend/tensorflow/trainer.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/keras_core/backend/tensorflow/trainer.py b/keras_core/backend/tensorflow/trainer.py index 1ba3c80a6..fe0404b1a 100644 --- a/keras_core/backend/tensorflow/trainer.py +++ b/keras_core/backend/tensorflow/trainer.py @@ -835,8 +835,9 @@ def potentially_ragged_concat(tensors): tensors: List of `Tensor`s. Returns: - Concatenation of the inputs along the first dimension -- of type `Tensor` - if all input shapes are compatible, or `RaggedTensor` if not. + Concatenation of the inputs along the first dimension -- of type + `Tensor` if all input shapes are compatible, or `RaggedTensor` + if not. """ if len(tensors) == 1: return tensors[0]