From a62bdc97183b6b7c59918d5c71a9b88b8e051440 Mon Sep 17 00:00:00 2001 From: Tirth Patel Date: Mon, 26 Jun 2023 20:27:48 +0000 Subject: [PATCH] Fix the test for the FunctionalModel --- keras_core/trainers/trainer_test.py | 59 +++++++++++++++++++---------- 1 file changed, 38 insertions(+), 21 deletions(-) diff --git a/keras_core/trainers/trainer_test.py b/keras_core/trainers/trainer_test.py index 8ac7b7cfa..9f069ea86 100644 --- a/keras_core/trainers/trainer_test.py +++ b/keras_core/trainers/trainer_test.py @@ -451,19 +451,48 @@ def call(self, x): y = np.zeros((16, 1)) model.fit(x, y, batch_size=4) + def get_layer(self): + class ExampleLayer(keras_core.Layer): + def call(self, x): + return x * 2 + + return ExampleLayer + + def get_model(self): + class ExampleModel(keras_core.Model): + def call(self, x): + return x * 2 + + return ExampleModel + + def get_functional(self): + class ExampleFunctional(keras_core.Functional): + def __init__(self, input_shape=(None,)): + inputs = keras_core.Input(input_shape) + # The functional model uses the + # ``tensorflow.experimental.numpy`` API which doesn't yet + # support RaggedTensors. So, most keras_core operations + # won't work when ragged tensors are passed to the Functional + # model. We just test that passing RaggedTensors works for + # now. + outputs = inputs + super().__init__(inputs=inputs, outputs=outputs) + + return ExampleFunctional + @parameterized.named_parameters( [ { - "testcase_name": "base_class_model", - "base_class": keras_core.Model, + "testcase_name": "model", + "model_class": "get_model", }, { - "testcase_name": "base_class_layer", - "base_class": keras_core.Layer, + "testcase_name": "layer", + "model_class": "get_layer", }, { - "testcase_name": "base_class_functional", - "base_class": keras_core.Functional, + "testcase_name": "functional", + "model_class": "get_functional", }, ] ) @@ -471,25 +500,13 @@ def call(self, x): keras_core.backend.backend() != "tensorflow", reason="Only tensorflow supports raggeds", ) - def test_trainer_with_raggeds(self, base_class): + def test_trainer_with_raggeds(self, model_class): import tensorflow as tf - class ExampleModel(base_class): - def __init__(self, input_shape=(None,)): - if base_class is keras_core.Functional: - inputs = keras_core.Input(input_shape) - outputs = inputs * 2 - super().__init__(inputs=inputs, outputs=outputs) - else: - super().__init__() - - def call(self, x): - return 2 * x - def loss_fn(y, y_pred, sample_weight=None): return 0 - model = ExampleModel() + model = getattr(self, model_class)()() x = tf.ragged.constant([[1], [2, 3]]) # test forward pass @@ -497,7 +514,7 @@ def loss_fn(y, y_pred, sample_weight=None): self.assertEqual(type(y), tf.RaggedTensor) # test training - if base_class in [keras_core.Model, keras_core.Functional]: + if model_class in ["get_model", "get_functional"]: model.compile(optimizer="adam", loss=loss_fn) model.fit(x, x) y = model.predict(x)