Skip to content

Commit

Permalink
updated unify.update such that duplicate logs are not repeated, and o…
Browse files Browse the repository at this point in the history
…nly the diff against the highest performing log is shown.
  • Loading branch information
djl11 committed Nov 1, 2024
1 parent b3f976c commit 80908b1
Showing 1 changed file with 14 additions and 3 deletions.
17 changes: 14 additions & 3 deletions unify/evals/assist/update.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,27 @@ def _print_table_from_dicts(dcts, col_list=None) -> str:
def _get_evals(logs: List[unify.Log], metric: str) -> str:

evals = {
k: v
k: [lg.to_json()["entries"] for lg in v]
for k, v in sorted(
unify.group_logs_by_configs(logs=logs).items(),
key=lambda item: sum([lg.entries[metric] for lg in item[1]]) / len(item[1]),
)
}

observed_entries = set()
evals_pruned = dict()
for config_str, entries in reversed(evals.items()):
evals_pruned[config_str] = list()
for entry in entries:
entry_str = json.dumps(entry)
if entry_str in observed_entries:
continue
observed_entries.add(entry_str)
evals_pruned[config_str].append(entry)
evals_pruned = {k: v for k, v in reversed(evals_pruned.items())}

ret = list()
for i, (config_str, logs) in enumerate(evals.items()):
for i, (config_str, entries) in enumerate(evals_pruned.items()):
ret.append(
f"Experiment {i}:\n" + len(str(i)) * "=" + "============\n",
)
Expand All @@ -48,7 +60,6 @@ def _get_evals(logs: List[unify.Log], metric: str) -> str:
ret.append(
"\n" " Logs:\n" " -----",
)
entries = [lg.to_json()["entries"] for lg in logs]
ret.append(
" " * 4
+ json.dumps(
Expand Down

0 comments on commit 80908b1

Please sign in to comment.