Skip to content

Commit

Permalink
Fixes layer index naming issue with new Keras weights saving.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 543503892
  • Loading branch information
nkovela1 authored and tensorflower-gardener committed Jun 26, 2023
1 parent 5f9d052 commit 5040e9a
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 2 deletions.
13 changes: 11 additions & 2 deletions keras/saving/saving_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,6 @@ def save_model(model, filepath, weights_format="h5"):
zip_filepath = filepath
try:
with zipfile.ZipFile(zip_filepath, "w") as zf:

with zf.open(_METADATA_FILENAME, "w") as f:
f.write(metadata_json.encode())
with zf.open(_CONFIG_FILENAME, "w") as f:
Expand Down Expand Up @@ -233,7 +232,6 @@ def load_model(filepath, custom_objects=None, compile=True, safe_mode=True):
with tf.io.gfile.GFile(
filepath, mode="r+b"
) as gfile_handle, zipfile.ZipFile(gfile_handle, "r") as zf:

with zf.open(_CONFIG_FILENAME, "r") as f:
config_json = f.read()

Expand Down Expand Up @@ -484,6 +482,10 @@ def _save_container_state(

for trackable in container:
if _is_keras_trackable(trackable):
# Keeps layer name indexing in proper order
# when duplicate layers are in container.
if id(trackable) in visited_trackables:
continue
# Do NOT address the trackable via `trackable.name`, since
# names are usually autogenerated and thus not reproducible
# (i.e. they may vary across two instances of the same model).
Expand Down Expand Up @@ -516,6 +518,13 @@ def _load_container_state(

for trackable in container:
if _is_keras_trackable(trackable):
# Keeps layer name indexing in proper order
# when duplicate layers are in container.
if visited_trackables and id(trackable) in visited_trackables:
continue
# Do NOT address the trackable via `trackable.name`, since
# names are usually autogenerated and thus not reproducible
# (i.e. they may vary across two instances of the same model).
name = generic_utils.to_snake_case(trackable.__class__.__name__)
if name in used_names:
used_names[name] += 1
Expand Down
19 changes: 19 additions & 0 deletions keras/saving/saving_lib_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from pathlib import Path
from unittest import mock

import h5py
import numpy as np
import tensorflow.compat.v2 as tf
from absl.testing import parameterized
Expand Down Expand Up @@ -734,6 +735,24 @@ def test_normalization_kpl(self):
out = model(data)
self.assertAllClose(ref_out, out, atol=1e-6)

def test_layer_index_naming(self):
weights_filepath = os.path.join(self.get_temp_dir(), "model.weights.h5")
model = keras.Sequential(
[
keras.layers.Dense(10),
keras.layers.Dense(10),
keras.layers.Dense(10),
keras.layers.Dense(10),
]
)
model.build([1, 20])
model.save_weights(weights_filepath)
with h5py.File(weights_filepath, "r") as f:
self.assertAllEqual(
list(f["_layer_checkpoint_dependencies"].keys()),
["dense", "dense_1", "dense_2", "dense_3"],
)


# This custom class lacks custom object registration.
class CustomRNN(keras.layers.Layer):
Expand Down

0 comments on commit 5040e9a

Please sign in to comment.