Skip to content

Commit

Permalink
Adds Keras v3 saving testing coverage to Keras layers tests.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 527921888
  • Loading branch information
nkovela1 authored and tensorflower-gardener committed Apr 28, 2023
1 parent e7c4d09 commit 1b7c53d
Show file tree
Hide file tree
Showing 8 changed files with 359 additions and 143 deletions.
21 changes: 18 additions & 3 deletions keras/layers/attention/multi_head_attention_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from absl.testing import parameterized

import keras
from keras.saving import object_registration
from keras.testing_infra import test_combinations
from keras.testing_infra import test_utils

Expand Down Expand Up @@ -515,6 +516,7 @@ def test_initializer(self):
self.assertEqual(output.shape.as_list(), [None, 40, 80])


@object_registration.register_keras_serializable()
class TestModel(keras.Model):
def __init__(self):
super().__init__()
Expand All @@ -540,12 +542,19 @@ def call(self, x, training=False):

@test_combinations.run_all_keras_modes(always_skip_v1=True)
class KerasModelSavingTest(test_combinations.TestCase):
def test_keras_saving_subclass(self):
@parameterized.parameters("tf", "keras_v3")
def test_keras_saving_subclass(self, save_format):
model = TestModel()
query = keras.Input(shape=(40, 80))
_ = model(query)
model_path = self.get_temp_dir() + "/tmp_model"
keras.models.save_model(model, model_path, save_format="tf")
if save_format == "keras_v3":
if not tf.__internal__.tf2.enabled():
self.skipTest(
"TF2 must be enabled to use the new `.keras` saving."
)
model_path += ".keras"
keras.models.save_model(model, model_path, save_format=save_format)
reloaded_model = keras.models.load_model(model_path)
self.assertEqual(
len(model.trainable_variables),
Expand All @@ -556,7 +565,7 @@ def test_keras_saving_subclass(self):
):
self.assertAllEqual(src_v, loaded_v)

@parameterized.parameters("h5", "tf")
@parameterized.parameters("h5", "tf", "keras_v3")
def test_keras_saving_functional(self, save_format):
model = TestModel()
query = keras.Input(shape=(40, 80))
Expand All @@ -565,6 +574,12 @@ def test_keras_saving_functional(self, save_format):
)(query, query)
model = keras.Model(inputs=query, outputs=output)
model_path = self.get_temp_dir() + "/tmp_model"
if save_format == "keras_v3":
if not tf.__internal__.tf2.enabled():
self.skipTest(
"TF2 must be enabled to use the new `.keras` saving."
)
model_path += ".keras"
keras.models.save_model(model, model_path, save_format=save_format)
reloaded_model = keras.models.load_model(model_path)
self.assertEqual(
Expand Down
25 changes: 20 additions & 5 deletions keras/layers/normalization/spectral_normalization_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,27 @@ def test_save_load_model(self):
# initialize model
model.predict(tf.random.uniform((2, 1)))

model.save("test.h5")
new_model = keras.models.load_model("test.h5")
with self.subTest("h5"):
model.save("test.h5")
new_model = keras.models.load_model("test.h5")

self.assertEqual(
model.layers[0].get_config(), new_model.layers[0].get_config()
)
self.assertEqual(
model.layers[0].get_config(), new_model.layers[0].get_config()
)
with self.subTest("savedmodel"):
model.save("test")
new_model = keras.models.load_model("test")

self.assertEqual(
model.layers[0].get_config(), new_model.layers[0].get_config()
)
with self.subTest("keras_v3"):
model.save("test.keras")
new_model = keras.models.load_model("test.keras")

self.assertEqual(
model.layers[0].get_config(), new_model.layers[0].get_config()
)

@test_combinations.run_all_keras_modes
def test_normalization(self):
Expand Down
44 changes: 33 additions & 11 deletions keras/layers/preprocessing/hashed_crossing_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def test_from_config(self):
tf.sparse.to_dense(original_outputs),
)

def test_saved_model_keras(self):
def test_saving_keras(self):
string_in = keras.Input(shape=(1,), dtype=tf.string)
int_in = keras.Input(shape=(1,), dtype=tf.int64)
out = hashed_crossing.HashedCrossing(num_bins=10)((string_in, int_in))
Expand All @@ -167,17 +167,39 @@ def test_saved_model_keras(self):
output_data = model((string_data, int_data))
self.assertAllClose(output_data, expected_output)

# Save the model to disk.
output_path = os.path.join(self.get_temp_dir(), "saved_model")
model.save(output_path, save_format="tf")
loaded_model = keras.models.load_model(
output_path,
custom_objects={"HashedCrossing": hashed_crossing.HashedCrossing},
)
with self.subTest("savedmodel"):
# Save the model to disk.
output_path = os.path.join(self.get_temp_dir(), "saved_model")
model.save(output_path, save_format="tf")
loaded_model = keras.models.load_model(
output_path,
custom_objects={
"HashedCrossing": hashed_crossing.HashedCrossing
},
)

# Validate correctness of the new model.
new_output_data = loaded_model((string_data, int_data))
self.assertAllClose(new_output_data, expected_output)

with self.subTest("keras_v3"):
if not tf.__internal__.tf2.enabled():
self.skipTest(
"TF2 must be enabled to use the new `.keras` saving."
)
# Save the model to disk.
output_path = os.path.join(self.get_temp_dir(), "model.keras")
model.save(output_path, save_format="keras_v3")
loaded_model = keras.models.load_model(
output_path,
custom_objects={
"HashedCrossing": hashed_crossing.HashedCrossing
},
)

# Validate correctness of the new model.
new_output_data = loaded_model((string_data, int_data))
self.assertAllClose(new_output_data, expected_output)
# Validate correctness of the new model.
new_output_data = loaded_model((string_data, int_data))
self.assertAllClose(new_output_data, expected_output)


if __name__ == "__main__":
Expand Down
24 changes: 24 additions & 0 deletions keras/layers/preprocessing/hashing_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,30 @@ def test_saved_model(self):
new_output_data = loaded_model(input_data)
self.assertAllClose(new_output_data, original_output_data)

@test_utils.run_v2_only
def test_save_keras_v3(self):
input_data = np.array(
["omar", "stringer", "marlo", "wire", "skywalker"]
)

inputs = keras.Input(shape=(None,), dtype=tf.string)
outputs = hashing.Hashing(num_bins=100)(inputs)
model = keras.Model(inputs=inputs, outputs=outputs)

original_output_data = model(input_data)

# Save the model to disk.
output_path = os.path.join(self.get_temp_dir(), "tf_keras_model.keras")
model.save(output_path, save_format="keras_v3")
loaded_model = keras.models.load_model(output_path)

# Ensure that the loaded model is unique (so that the save/load is real)
self.assertIsNot(model, loaded_model)

# Validate correctness of the new model.
new_output_data = loaded_model(input_data)
self.assertAllClose(new_output_data, original_output_data)

@parameterized.named_parameters(
(
"list_input",
Expand Down
Loading

0 comments on commit 1b7c53d

Please sign in to comment.