Skip to content

Commit

Permalink
Avoid redundancy in file comparison
Browse files Browse the repository at this point in the history
  • Loading branch information
fchollet committed Sep 23, 2024
1 parent 36a01f1 commit 0acd575
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 1 deletion.
12 changes: 11 additions & 1 deletion keras/src/saving/file_editor.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,13 +138,18 @@ def _compare(
ref_name,
error_count,
match_count,
checked_paths,
):
base_inner_path = inner_path
for ref_key, ref_val in ref_spec.items():
inner_path = base_inner_path + "/" + ref_key
if inner_path in checked_paths:
continue

if ref_key not in target:
error_count += 1
checked_paths.add(inner_path)
if isinstance(ref_val, dict):
error_count += 1
self.console.print(
f"[color(160)]...Object [bold]{inner_path}[/] "
f"present in {ref_name}, "
Expand All @@ -169,12 +174,14 @@ def _compare(
ref_name,
error_count=error_count,
match_count=match_count,
checked_paths=checked_paths,
)
error_count += _error_count
match_count += _match_count
else:
if target[ref_key].shape != ref_val.shape:
error_count += 1
checked_paths.add(inner_path)
self.console.print(
f"[color(160)]...Weight shape mismatch "
f"for [bold]{inner_path}[/][/]\n"
Expand All @@ -187,6 +194,7 @@ def _compare(
match_count += 1
return error_count, match_count

checked_paths = set()
error_count, match_count = _compare(
self.weights_dict,
ref_spec,
Expand All @@ -195,6 +203,7 @@ def _compare(
ref_name="reference model",
error_count=0,
match_count=0,
checked_paths=checked_paths,
)
_error_count, _ = _compare(
ref_spec,
Expand All @@ -204,6 +213,7 @@ def _compare(
ref_name="saved file",
error_count=0,
match_count=0,
checked_paths=checked_paths,
)
error_count += _error_count
self.console.print("─────────────────────")
Expand Down
15 changes: 15 additions & 0 deletions keras/src/saving/file_editor_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,18 @@ def test_basics(self):
editor.summary()
out = editor.compare_to(target_model) # Succeeds
self.assertEqual(out["status"], "success")

editor.delete_weight("dense_2", "1")
out = editor.compare_to(target_model) # Fails
self.assertEqual(out["status"], "error")
self.assertEqual(out["error_count"], 1)

editor.add_weights("dense_2", {"1": np.zeros((7,))})
out = editor.compare_to(target_model) # Fails
self.assertEqual(out["status"], "error")
self.assertEqual(out["error_count"], 1)

editor.delete_weight("dense_2", "1")
editor.add_weights("dense_2", {"1": np.zeros((3,))})
out = editor.compare_to(target_model) # Succeeds
self.assertEqual(out["status"], "success")

0 comments on commit 0acd575

Please sign in to comment.