From 6db98722b9c210560cad338f74868e480cc18d8b Mon Sep 17 00:00:00 2001 From: Neel Kovelamudi <60985914+nkovela1@users.noreply.github.com> Date: Mon, 25 Sep 2023 10:15:34 -0700 Subject: [PATCH] Add associated test --- keras/optimizers/legacy/optimizer_v2_test.py | 28 ++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/keras/optimizers/legacy/optimizer_v2_test.py b/keras/optimizers/legacy/optimizer_v2_test.py index 47ffec24453..2d6bbc14e1d 100644 --- a/keras/optimizers/legacy/optimizer_v2_test.py +++ b/keras/optimizers/legacy/optimizer_v2_test.py @@ -15,6 +15,7 @@ """Functional test for OptimizerV2.""" import collections +import os from copy import deepcopy import numpy as np @@ -568,6 +569,33 @@ def testOptimizerWithKerasModel(self): batch_size=5, ) + @test_combinations.generate(test_combinations.combine(mode=["eager"])) + def testOptimizerSaving(self): + np.random.seed(1331) + input_np = np.random.random((10, 3)) + output_np = np.random.random((10, 4)) + a = input_layer.Input(shape=(3,), name="input_a") + model = sequential.Sequential() + model.add(core.Dense(4, kernel_initializer="zeros", name="dense")) + model.add(regularization.Dropout(0.5, name="dropout")) + model(a) + optimizer = gradient_descent.SGD(learning_rate=0.1) + model.compile(optimizer, loss="mse", metrics=["mae"]) + + model.fit( + input_np, + output_np, + batch_size=10, + validation_data=(input_np, output_np), + epochs=2, + verbose=0, + ) + + temp_filepath = os.path.join(self.get_temp_dir(), "optv2_model.keras") + model.save(temp_filepath) + loaded_model = keras.models.load_model(temp_filepath) + self.assertAllClose(model(input_np), loaded_model(input_np), atol=1e-6) + @test_combinations.generate( test_combinations.combine(mode=["graph", "eager"]) )