diff --git a/keras/src/saving/file_editor.py b/keras/src/saving/file_editor.py index cdaa1601757..68e21128cc7 100644 --- a/keras/src/saving/file_editor.py +++ b/keras/src/saving/file_editor.py @@ -36,6 +36,10 @@ def is_ipython_notebook(): class KerasFileEditor: """Utility to inspect, edit, and resave Keras weights files. + You will find this class useful when adapting + an old saved weights file after having made + architecture changes to a model. + Args: filepath: The path to a local file to inspect and edit. @@ -43,12 +47,16 @@ class KerasFileEditor: ```python editor = KerasFileEditor("my_model.weights.h5") + # Displays current contents editor.summary() + # Remove the weights of an existing layer editor.delete_object("layers/dense_2") + # Add the weights of a new layer editor.add_object("layers/einsum_dense", weights={"0": ..., "1": ...}) + # Save the weights of the edited model editor.resave_weights("edited_model.weights.h5") ``` @@ -90,6 +98,7 @@ def __init__( weights_dict, object_metadata = self._extract_weights_from_store( weights_store.h5_file ) + weights_store.close() self.weights_dict = weights_dict self.object_metadata = object_metadata # {path: object_name} self.console.print(self._generate_filepath_info(rich_style=True)) @@ -170,7 +179,7 @@ def _compare( f"[color(160)]...Weight shape mismatch " f"for [bold]{inner_path}[/][/]\n" f" In {ref_name}: " - f"shape={tuple(ref_val[0])}\n" + f"shape={ref_val.shape}\n" f" In {target_name}: " f"shape={target[ref_key].shape}" )