[RLlib] Cleanup examples folder #14: Add example script for how to resume a tune.Tuner.fit() experiment from a checkpoint. #45681

Merged
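The headline change is a new example script showing how to resume a tune.Tuner.fit() experiment from its checkpoints. A minimal sketch of that pattern, assuming Ray 2.x Tune APIs (the experiment name, stop criterion, and checkpoint frequency below are illustrative, not taken from the script):

    from ray import train, tune
    from ray.rllib.algorithms.ppo import PPOConfig

    config = PPOConfig().environment("CartPole-v1")

    # Initial run: Tune writes checkpoints under <storage_path>/<name>.
    tuner = tune.Tuner(
        "PPO",
        param_space=config,
        run_config=train.RunConfig(
            name="continue_training_demo",  # hypothetical experiment name
            checkpoint_config=train.CheckpointConfig(checkpoint_frequency=1),
            stop={"training_iteration": 100},
        ),
    )
    tuner.fit()

    # After a crash or interruption: restore the experiment state from its
    # directory and call fit() again to continue where the run left off.
    tuner = tune.Tuner.restore(
        path="~/ray_results/continue_training_demo",
        trainable="PPO",
        resume_errored=True,
    )
    tuner.fit()

Tuner.restore() rebuilds the experiment state from the run directory, so the original param_space does not need to be passed again; resume_errored=True also resumes previously crashed trials from their latest checkpoints.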
rllib/BUILD (36 additions, 9 deletions)

@@ -2120,15 +2120,6 @@ py_test(
 # subdirectory: checkpoints/
 # ....................................

-#@OldAPIStack
-py_test(
-    name = "examples/checkpoints/cartpole_dqn_export",
-    main = "examples/checkpoints/cartpole_dqn_export.py",
-    tags = ["team:rllib", "exclusive", "examples"],
-    size = "small",
-    srcs = ["examples/checkpoints/cartpole_dqn_export.py"],
-)
-
 py_test(
     name = "examples/checkpoints/checkpoint_by_custom_criteria",
     main = "examples/checkpoints/checkpoint_by_custom_criteria.py",

@@ -2138,6 +2129,42 @@ py_test(
     args = ["--enable-new-api-stack", "--stop-reward=150.0", "--num-cpus=8"]
 )

+py_test(
+    name = "examples/checkpoints/continue_training_from_checkpoint",
+    main = "examples/checkpoints/continue_training_from_checkpoint.py",
+    tags = ["team:rllib", "exclusive", "examples"],
+    size = "large",
+    srcs = ["examples/checkpoints/continue_training_from_checkpoint.py"],
+    args = ["--enable-new-api-stack", "--as-test"]
+)
+
+py_test(
+    name = "examples/checkpoints/continue_training_from_checkpoint_multi_agent",
+    main = "examples/checkpoints/continue_training_from_checkpoint.py",
+    tags = ["team:rllib", "exclusive", "examples"],
+    size = "large",
+    srcs = ["examples/checkpoints/continue_training_from_checkpoint.py"],
+    args = ["--enable-new-api-stack", "--as-test", "--num-agents=2", "--stop-reward-crash=400.0", "--stop-reward=900.0"]
+)
+
+#@OldAPIStack
+py_test(
+    name = "examples/checkpoints/continue_training_from_checkpoint_old_api_stack",
+    main = "examples/checkpoints/continue_training_from_checkpoint.py",
+    tags = ["team:rllib", "exclusive", "examples"],
+    size = "large",
+    srcs = ["examples/checkpoints/continue_training_from_checkpoint.py"],
+    args = ["--as-test"]
+)
+
+py_test(
+    name = "examples/checkpoints/cartpole_dqn_export",
+    main = "examples/checkpoints/cartpole_dqn_export.py",
+    tags = ["team:rllib", "exclusive", "examples"],
+    size = "small",
+    srcs = ["examples/checkpoints/cartpole_dqn_export.py"],
+)
+
 #@OldAPIStack
 py_test(
     name = "examples/checkpoints/onnx_tf2",
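The BUILD entries above pass their args through to the example script, so the new example should be runnable locally the same way CI runs it, e.g.:

    python rllib/examples/checkpoints/continue_training_from_checkpoint.py --enable-new-api-stack --as-test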
rllib/algorithms/dqn/dqn.py (11 additions, 6 deletions)

@@ -630,22 +630,27 @@ def _training_step_new_api_stack(self, *, with_noise_reset) -> ResultDict:
         )

         self.metrics.log_dict(
-            self.metrics.peek(ENV_RUNNER_RESULTS, NUM_AGENT_STEPS_SAMPLED, default={}),
+            self.metrics.peek(
+                (ENV_RUNNER_RESULTS, NUM_AGENT_STEPS_SAMPLED), default={}
+            ),
             key=NUM_AGENT_STEPS_SAMPLED_LIFETIME,
             reduce="sum",
         )
         self.metrics.log_value(
             NUM_ENV_STEPS_SAMPLED_LIFETIME,
-            self.metrics.peek(ENV_RUNNER_RESULTS, NUM_ENV_STEPS_SAMPLED, default=0),
+            self.metrics.peek((ENV_RUNNER_RESULTS, NUM_ENV_STEPS_SAMPLED), default=0),
             reduce="sum",
         )
         self.metrics.log_value(
             NUM_EPISODES_LIFETIME,
-            self.metrics.peek(ENV_RUNNER_RESULTS, NUM_EPISODES, default=0),
+            self.metrics.peek((ENV_RUNNER_RESULTS, NUM_EPISODES), default=0),
             reduce="sum",
         )
         self.metrics.log_dict(
-            self.metrics.peek(ENV_RUNNER_RESULTS, NUM_MODULE_STEPS_SAMPLED, default={}),
+            self.metrics.peek(
+                (ENV_RUNNER_RESULTS, NUM_MODULE_STEPS_SAMPLED),
+                default={},
+            ),
             key=NUM_MODULE_STEPS_SAMPLED_LIFETIME,
             reduce="sum",
         )
@@ -708,7 +713,7 @@ def _training_step_new_api_stack(self, *, with_noise_reset) -> ResultDict:
             self.metrics.log_value(
                 NUM_ENV_STEPS_TRAINED_LIFETIME,
                 self.metrics.peek(
-                    LEARNER_RESULTS, ALL_MODULES, NUM_ENV_STEPS_TRAINED
+                    (LEARNER_RESULTS, ALL_MODULES, NUM_ENV_STEPS_TRAINED)
                 ),
                 reduce="sum",
             )
@@ -725,7 +730,7 @@ def _training_step_new_api_stack(self, *, with_noise_reset) -> ResultDict:
             # TODO (sven): Uncomment this once agent steps are available in the
             # Learner stats.
             # self.metrics.log_dict(self.metrics.peek(
-            #     LEARNER_RESULTS, NUM_AGENT_STEPS_TRAINED, default={}
+            #     (LEARNER_RESULTS, NUM_AGENT_STEPS_TRAINED), default={}
             # ), key=NUM_AGENT_STEPS_TRAINED_LIFETIME, reduce="sum")

             # Update replay buffer priorities.
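The change that recurs through this file and the algorithm files below is an API shift in RLlib's MetricsLogger: peek() now takes a nested key path as a single tuple instead of as separate positional arguments. A minimal sketch of the new call style, assuming MetricsLogger can be constructed standalone as in rllib/utils/metrics (the key names and values here are illustrative):

    from ray.rllib.utils.metrics.metrics_logger import MetricsLogger

    metrics = MetricsLogger()

    # Values live under nested key paths; a tuple addresses the nesting.
    metrics.log_value(("env_runners", "num_env_steps_sampled"), 128, reduce="sum")
    metrics.log_value(("env_runners", "num_env_steps_sampled"), 64, reduce="sum")

    # New style: peek() takes the whole key path as one tuple (plus an
    # optional default), not as positional components as before this PR.
    print(metrics.peek(("env_runners", "num_env_steps_sampled"), default=0))  # -> 192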
rllib/algorithms/dreamerv3/dreamerv3.py (3 additions, 3 deletions)

@@ -582,13 +582,13 @@ def training_step(self) -> ResultDict:
         self.metrics.log_dict(
             {
                 NUM_AGENT_STEPS_SAMPLED_LIFETIME: self.metrics.peek(
-                    ENV_RUNNER_RESULTS, NUM_AGENT_STEPS_SAMPLED
+                    (ENV_RUNNER_RESULTS, NUM_AGENT_STEPS_SAMPLED)
                 ),
                 NUM_ENV_STEPS_SAMPLED_LIFETIME: self.metrics.peek(
-                    ENV_RUNNER_RESULTS, NUM_ENV_STEPS_SAMPLED
+                    (ENV_RUNNER_RESULTS, NUM_ENV_STEPS_SAMPLED)
                 ),
                 NUM_EPISODES_LIFETIME: self.metrics.peek(
-                    ENV_RUNNER_RESULTS, NUM_EPISODES
+                    (ENV_RUNNER_RESULTS, NUM_EPISODES)
                 ),
             },
             reduce="sum",
rllib/algorithms/dreamerv3/tf/dreamerv3_tf_learner.py (1 addition, 1 deletion)

@@ -158,7 +158,7 @@ def compute_gradients(
                 # Take individual loss term from the registered metrics for
                 # the main module.
                 self.metrics.peek(
-                    DEFAULT_MODULE_ID, component.upper() + "_L_total"
+                    (DEFAULT_MODULE_ID, component.upper() + "_L_total")
                 ),
                 self.filter_param_dict_for_optimizer(
                     self._params, self.get_optimizer(optimizer_name=component)
rllib/algorithms/dreamerv3/utils/summaries.py (1 addition, 3 deletions)

@@ -217,9 +217,7 @@ def report_dreamed_eval_trajectory_vs_samples(
         the report/videos.
     """
     dream_data = metrics.peek(
-        LEARNER_RESULTS,
-        DEFAULT_MODULE_ID,
-        "dream_data",
+        (LEARNER_RESULTS, DEFAULT_MODULE_ID, "dream_data"),
         default={},
     )
     metrics.delete(LEARNER_RESULTS, DEFAULT_MODULE_ID, "dream_data", key_error=False)
rllib/algorithms/ppo/ppo.py (8 additions, 6 deletions)

@@ -463,13 +463,13 @@ def _training_step_new_api_stack(self) -> ResultDict:
         self.metrics.log_dict(
             {
                 NUM_AGENT_STEPS_SAMPLED_LIFETIME: self.metrics.peek(
-                    ENV_RUNNER_RESULTS, NUM_AGENT_STEPS_SAMPLED
+                    (ENV_RUNNER_RESULTS, NUM_AGENT_STEPS_SAMPLED)
                 ),
                 NUM_ENV_STEPS_SAMPLED_LIFETIME: self.metrics.peek(
-                    ENV_RUNNER_RESULTS, NUM_ENV_STEPS_SAMPLED
+                    (ENV_RUNNER_RESULTS, NUM_ENV_STEPS_SAMPLED)
                 ),
                 NUM_EPISODES_LIFETIME: self.metrics.peek(
-                    ENV_RUNNER_RESULTS, NUM_EPISODES
+                    (ENV_RUNNER_RESULTS, NUM_EPISODES)
                 ),
             },
             reduce="sum",
@@ -494,10 +494,10 @@ def _training_step_new_api_stack(self) -> ResultDict:
             self.metrics.log_dict(
                 {
                     NUM_ENV_STEPS_TRAINED_LIFETIME: self.metrics.peek(
-                        LEARNER_RESULTS, ALL_MODULES, NUM_ENV_STEPS_TRAINED
+                        (LEARNER_RESULTS, ALL_MODULES, NUM_ENV_STEPS_TRAINED)
                     ),
                     # NUM_MODULE_STEPS_TRAINED_LIFETIME: self.metrics.peek(
-                    #     LEARNER_RESULTS, NUM_MODULE_STEPS_TRAINED
+                    #     (LEARNER_RESULTS, NUM_MODULE_STEPS_TRAINED)
                     # ),
                 },
                 reduce="sum",
@@ -531,7 +531,9 @@ def _training_step_new_api_stack(self) -> ResultDict:
         if self.config.use_kl_loss:
             for mid in modules_to_update:
                 kl = convert_to_numpy(
-                    self.metrics.peek(LEARNER_RESULTS, mid, LEARNER_RESULTS_KL_KEY)
+                    self.metrics.peek(
+                        (LEARNER_RESULTS, mid, LEARNER_RESULTS_KL_KEY)
+                    )
                 )
                 if np.isnan(kl):
                     logger.warning(
rllib/algorithms/sac/torch/sac_torch_learner.py (1 addition, 1 deletion)

@@ -314,7 +314,7 @@ def compute_gradients(
         for component in (
             ["qf", "policy", "alpha"] + ["qf_twin"] if config.twin_q else []
         ):
-            self.metrics.peek(module_id, component + "_loss").backward(
+            self.metrics.peek((module_id, component + "_loss")).backward(
                 retain_graph=True
             )
             grads.update(
rllib/examples/checkpoints/checkpoint_by_custom_criteria.py (1 addition, 1 deletion)

@@ -1,7 +1,7 @@
"""Example extracting a checkpoint from n trials using one or more custom criteria.

This example:
- runs a simple CartPole experiment with three different learning rates (three tune
- runs a CartPole experiment with three different learning rates (three tune
"trials"). During the experiment, for each trial, we create a checkpoint at each
iteration.
- at the end of the experiment, we compare the trials and pick the one that performed
Expand Down
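The workflow this docstring describes can be sketched with Tune's ResultGrid API. A hedged sketch, where the grid-searched learning rates and the metric key are assumptions rather than values copied from the script:

    from ray import train, tune
    from ray.rllib.algorithms.ppo import PPOConfig

    # Three learning rates -> three Tune trials, each checkpointing every iteration.
    config = (
        PPOConfig()
        .environment("CartPole-v1")
        .training(lr=tune.grid_search([0.01, 0.001, 0.0001]))
    )
    tuner = tune.Tuner(
        "PPO",
        param_space=config,
        run_config=train.RunConfig(
            checkpoint_config=train.CheckpointConfig(checkpoint_frequency=1),
            stop={"training_iteration": 5},
        ),
    )
    results = tuner.fit()

    # Compare the trials on a custom criterion and grab the winner's checkpoint.
    best_result = results.get_best_result(
        metric="env_runners/episode_return_mean",  # assumed metric key
        mode="max",
    )
    best_checkpoint = best_result.checkpoint  # checkpoint of the best trial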