Skip to content

Commit

Permalink
Fix the test for the FunctionalModel
Browse files Browse the repository at this point in the history
  • Loading branch information
tirthasheshpatel committed Jun 26, 2023
1 parent bdd765c commit a62bdc9
Showing 1 changed file with 38 additions and 21 deletions.
59 changes: 38 additions & 21 deletions keras_core/trainers/trainer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,53 +451,70 @@ def call(self, x):
y = np.zeros((16, 1))
model.fit(x, y, batch_size=4)

def get_layer(self):
class ExampleLayer(keras_core.Layer):
def call(self, x):
return x * 2

return ExampleLayer

def get_model(self):
class ExampleModel(keras_core.Model):
def call(self, x):
return x * 2

return ExampleModel

def get_functional(self):
class ExampleFunctional(keras_core.Functional):
def __init__(self, input_shape=(None,)):
inputs = keras_core.Input(input_shape)
# The functional model uses the
# ``tensorflow.experimental.numpy`` API which doesn't yet
# support RaggedTensors. So, most keras_core operations
# won't work when ragged tensors are passed to the Functional
# model. We just test that passing RaggedTensors works for
# now.
outputs = inputs
super().__init__(inputs=inputs, outputs=outputs)

return ExampleFunctional

@parameterized.named_parameters(
[
{
"testcase_name": "base_class_model",
"base_class": keras_core.Model,
"testcase_name": "model",
"model_class": "get_model",
},
{
"testcase_name": "base_class_layer",
"base_class": keras_core.Layer,
"testcase_name": "layer",
"model_class": "get_layer",
},
{
"testcase_name": "base_class_functional",
"base_class": keras_core.Functional,
"testcase_name": "functional",
"model_class": "get_functional",
},
]
)
@pytest.mark.skipif(
keras_core.backend.backend() != "tensorflow",
reason="Only tensorflow supports raggeds",
)
def test_trainer_with_raggeds(self, base_class):
def test_trainer_with_raggeds(self, model_class):
import tensorflow as tf

class ExampleModel(base_class):
def __init__(self, input_shape=(None,)):
if base_class is keras_core.Functional:
inputs = keras_core.Input(input_shape)
outputs = inputs * 2
super().__init__(inputs=inputs, outputs=outputs)
else:
super().__init__()

def call(self, x):
return 2 * x

def loss_fn(y, y_pred, sample_weight=None):
return 0

model = ExampleModel()
model = getattr(self, model_class)()()
x = tf.ragged.constant([[1], [2, 3]])

# test forward pass
y = model(x)
self.assertEqual(type(y), tf.RaggedTensor)

# test training
if base_class in [keras_core.Model, keras_core.Functional]:
if model_class in ["get_model", "get_functional"]:
model.compile(optimizer="adam", loss=loss_fn)
model.fit(x, x)
y = model.predict(x)
Expand Down

0 comments on commit a62bdc9

Please sign in to comment.