Add benchmark analysis notebook.
ernestum committed Feb 29, 2024
1 parent 288d38d commit 2e356b7
213 changes: 213 additions & 0 deletions tuning/benchmark_analysis.ipynb
@@ -0,0 +1,213 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "5c06148d9ff6b57",
"metadata": {
"collapsed": false
},
"source": [
"This notebook loads all the optuna studies in the \"tuning\" folder and arranges them in a dataframe. It also loads the performance of the best model from the paper and the rerun results.\n",
"\n",
"It can serve as a starting point for further analysis."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "31e6f532-15c3-494a-8a3a-de25ecc1ee90",
"metadata": {},
"outputs": [],
"source": [
"# Load all the studies into a dataframe\n",
"\n",
"import optuna\n",
"from collections import Counter\n",
"from optuna.trial import TrialState\n",
"import pandas as pd\n",
"import numpy as np\n",
"import datetime\n",
"from pathlib import Path\n",
"\n",
"import imitation.util.sacred_file_parsing as sfp\n",
"\n",
"\n",
"experiment_log_files = list(Path().glob(\"*/*.log\"))\n",
"\n",
"experiment_log_files\n",
"\n",
"raw_study_data = []\n",
"\n",
"for log_file in experiment_log_files:\n",
" d = dict()\n",
" \n",
" d['logfile'] = log_file\n",
" \n",
" study = optuna.load_study(storage=optuna.storages.JournalStorage(\n",
" optuna.storages.JournalFileStorage(str(log_file))\n",
" ),\n",
" # in our case, we have one journal file per study so the study name can be\n",
" # inferred\n",
" study_name=None,\n",
" )\n",
" d['study'] = study\n",
" d['study_name'] = study.study_name\n",
" \n",
" trial_state_counter = Counter(t.state for t in study.trials)\n",
" n_completed_trials = trial_state_counter[TrialState.COMPLETE]\n",
" d['trials'] = n_completed_trials\n",
" d['trials_running'] = Counter(t.state for t in study.trials)[TrialState.RUNNING]\n",
" d['trials_failed'] = Counter(t.state for t in study.trials)[TrialState.FAIL]\n",
" d['all_trials'] = len(study.trials)\n",
" \n",
" if n_completed_trials > 0:\n",
" d['best_value'] = round(study.best_trial.value, 2)\n",
" \n",
" assert \"_\" in study.study_name\n",
" study_segments = study.study_name.split(\"_\") \n",
" assert len(study_segments) > 3\n",
" tuning, algo, with_ = study_segments[:3]\n",
" assert (tuning, with_) == (\"tuning\", \"with\")\n",
" \n",
" d['algo'] = algo\n",
" d['env'] = \"_\".join(study_segments[3:])\n",
" d['best_trial_duration'] = study.best_trial.duration\n",
" d['mean_duration'] = sum([t.duration for t in study.trials if t.state == TrialState.COMPLETE], datetime.timedelta())/n_completed_trials\n",
" \n",
" reruns_folder = log_file.parent / \"reruns\"\n",
" rerun_results = [round(run['result']['imit_stats']['monitor_return_mean'], 2)\n",
" for conf, run in sfp.find_sacred_runs(reruns_folder, only_completed_runs=True)]\n",
" d['rerun_values'] = rerun_results\n",
" \n",
" raw_study_data.append(d)\n",
" \n",
"study_data = pd.DataFrame(raw_study_data)"
]
},
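{
"cell_type": "markdown",
"id": "a1f4c7e2b9d05a31",
"metadata": {
"collapsed": false
},
"source": [
"Optional sanity check (a sketch, not part of the original analysis): list the best hyperparameters each study found. Only studies with at least one completed trial have a well-defined best trial, and the `params` keys depend on the search space used during tuning."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b2e5d8f3c0a16b42",
"metadata": {},
"outputs": [],
"source": [
"# Sketch: print the best hyperparameters per study (assumes the\n",
"# raw_study_data entries built above).\n",
"for d in raw_study_data:\n",
"    if d.get('trials', 0) > 0:\n",
"        print(d['study_name'], '->', d['study'].best_trial.params)"
]
},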
{
"cell_type": "code",
"execution_count": null,
"id": "b604bc7e-2e61-4f7f-acfe-87b57e8a2f5a",
"metadata": {},
"outputs": [],
"source": [
"# Add performance of the best model from the paper\n",
"import pandas as pd\n",
"\n",
"environments = [\n",
" \"seals_ant\",\n",
" \"seals_half_cheetah\",\n",
" \"seals_hopper\",\n",
" \"seals_swimmer\",\n",
" \"seals_walker\",\n",
" \"seals_humanoid\",\n",
" \"seals_cartpole\",\n",
" \"pendulum\",\n",
" \"seals_mountain_car\"\n",
"]\n",
"\n",
"pc_paper_700 = dict(\n",
" seals_ant=200,\n",
" seals_half_cheetah=4700,\n",
" seals_hopper=4500,\n",
" seals_swimmer=170,\n",
" seals_walker=4900,\n",
" seals_humanoid=\"-\",\n",
" seals_cartpole=\"-\",\n",
" pendulum=1300,\n",
" seals_mountain_car=\"-\",\n",
")\n",
"\n",
"pc_paper_1400 = dict(\n",
" seals_ant=100,\n",
" seals_half_cheetah=5600,\n",
" seals_hopper=4500,\n",
" seals_swimmer=175,\n",
" seals_walker=5900,\n",
" seals_humanoid=\"-\",\n",
" seals_cartpole=\"-\",\n",
" pendulum=750,\n",
" seals_mountain_car=\"-\",\n",
")\n",
"\n",
"rl_paper = dict(\n",
" seals_ant=16,\n",
" seals_half_cheetah=420,\n",
" seals_hopper=4210,\n",
" seals_swimmer=175,\n",
" seals_walker=5370,\n",
" seals_humanoid=\"-\",\n",
" seals_cartpole=\"-\",\n",
" pendulum=1300,\n",
" seals_mountain_car=\"-\",\n",
")\n",
"\n",
"rl_ours = dict(\n",
" seals_ant=3034,\n",
" seals_half_cheetah=1675.76,\n",
" seals_hopper=203.45,\n",
" seals_swimmer=292.84,\n",
" seals_walker=2465.56,\n",
" seals_humanoid=3224.12,\n",
" seals_cartpole=500.00,\n",
" pendulum=-189.25,\n",
" seals_mountain_car=-97.00,\n",
")\n",
"\n",
"for algo, values_by_env in dict(\n",
" pc_paper_700=pc_paper_700,\n",
" pc_paper_1400=pc_paper_1400,\n",
" rl_paper=rl_paper,\n",
" rl_ours=rl_ours,\n",
").items():\n",
" for env, value in values_by_env.items():\n",
" if value == \"-\":\n",
" continue\n",
" raw_study_data.append(dict(\n",
" algo=algo,\n",
" env=env,\n",
" best_value=value,\n",
" ))\n",
" \n",
"study_data = pd.DataFrame(raw_study_data)"
]
},
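{
"cell_type": "markdown",
"id": "c3f6e9a4d1b27c53",
"metadata": {
"collapsed": false
},
"source": [
"A possible next step (a sketch, not part of the original analysis): pivot the best values into an environment-by-algorithm table for side-by-side comparison. This assumes the `study_data` frame built above; `aggfunc=\"max\"` is an arbitrary choice in case of duplicate (algo, env) pairs."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d4a7f0b5e2c38d64",
"metadata": {},
"outputs": [],
"source": [
"# Sketch: best value per environment (rows) and algorithm (columns).\n",
"study_data.pivot_table(index='env', columns='algo', values='best_value', aggfunc='max')"
]
},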
{
"cell_type": "code",
"execution_count": null,
"id": "2e9ae5ca-5002-411b-beaf-cb98eb12f54c",
"metadata": {},
"outputs": [],
"source": [
"from IPython.display import display\n",
"\n",
"print(\"Benchmark Data\")\n",
"display(study_data[[\"algo\", \"env\", \"best_value\"]])\n",
"\n",
"print(\"Rerun Data\")\n",
"display(study_data[[\"algo\", \"env\", \"best_value\", \"rerun_values\"]][study_data[\"rerun_values\"].map(np.std) > 0])"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
