Skip to content

Commit

Permalink
feat: pass in xlabel as an argument to sample efficiency plots
Browse files Browse the repository at this point in the history
  • Loading branch information
RuanJohn committed Dec 1, 2023
1 parent dc22136 commit 217a733
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 18 deletions.
39 changes: 23 additions & 16 deletions examples/quickstart.ipynb

Large diffs are not rendered by default.

8 changes: 6 additions & 2 deletions marl_eval/plotting_tools/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,7 @@ def sample_efficiency_curves(
dictionary: Dict[str, Dict[str, Any]],
metric_name: str,
metrics_to_normalize: List[str],
xlabel: str = "Timesteps",
) -> Tuple[Figure, Dict[str, np.ndarray], Dict[str, np.ndarray]]:
"""Produces sample efficiency curve plots.
Expand All @@ -284,6 +285,7 @@ def sample_efficiency_curves(
metric scores for metric algorithm pairs.
metric_name: Name of metric to produce plots for.
metrics_to_normalize: List of metrics that are normalised.
xlabel: Label for x-axis.
Returns:
fig: Matplotlib figure for storing.
Expand Down Expand Up @@ -335,7 +337,7 @@ def sample_efficiency_curves(
iqm_scores,
iqm_cis,
algorithms=algorithms,
xlabel=r"Number of timesteps (Millions)",
xlabel=xlabel,
ylabel=ylabel,
legend=algorithms,
figsize=(15, 8),
Expand All @@ -353,6 +355,7 @@ def plot_single_task(
task_name: str,
metric_name: str,
metrics_to_normalize: List[str],
xlabel: str = "Timesteps",
) -> Figure:
"""Produces aggregated plot for a single task in an environment.
Expand All @@ -362,6 +365,7 @@ def plot_single_task(
task_name: Name of task to produce plots for.
metric_name: Name of metric to produce plots for.
metrics_to_normalize: List of metrics that are normalised.
xlabel: Label for x-axis.
"""
metric_name, task_name, environment_name, metrics_to_normalize = lower_case_inputs(
metric_name, task_name, environment_name, metrics_to_normalize
Expand Down Expand Up @@ -392,7 +396,7 @@ def plot_single_task(
fig = plot_single_task_curve(
task_mean_ci_data,
algorithms=algorithms,
xlabel="Number of timesteps (Millions)",
xlabel=xlabel,
ylabel=ylabel,
legend=algorithms,
figsize=(15, 8),
Expand Down

0 comments on commit 217a733

Please sign in to comment.