[RLlib] Add before_sub_environment_reset callback. (ray-project#28566)
sven1977 authored Sep 19, 2022
1 parent 3268a90 commit 586c217
Showing 6 changed files with 188 additions and 21 deletions.
18 changes: 15 additions & 3 deletions doc/source/rllib/rllib-training.rst
@@ -1013,9 +1013,21 @@ Ray actors provide high levels of performance, so in more complex cases they can
Callbacks and Custom Metrics
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

You can provide callbacks to be called at points during policy evaluation. These callbacks have access to state for the current `episode <https://github.com/ray-project/ray/blob/master/rllib/evaluation/episode.py>`__. Certain callbacks such as ``on_postprocess_trajectory``, ``on_sample_end``, and ``on_train_result`` are also places where custom postprocessing can be applied to intermediate data or results.

User-defined state can be stored for the `episode <https://github.com/ray-project/ray/blob/master/rllib/evaluation/episode.py>`__ in the ``episode.user_data`` dict, and custom scalar metrics reported by saving values to the ``episode.custom_metrics`` dict. These custom metrics will be aggregated and reported as part of training results. For a full example, see `custom_metrics_and_callbacks.py <https://github.com/ray-project/ray/blob/master/rllib/examples/custom_metrics_and_callbacks.py>`__.
You can provide callbacks to be called at points during policy evaluation.
These callbacks have access to state for the current
`episode <https://github.com/ray-project/ray/blob/master/rllib/evaluation/episode.py>`__.
Certain callbacks such as ``on_postprocess_trajectory``, ``on_sample_end``,
and ``on_train_result`` are also places where custom postprocessing can be applied to
intermediate data or results.

User-defined state can be stored for the
`episode <https://github.com/ray-project/ray/blob/master/rllib/evaluation/episode.py>`__
in the ``episode.user_data`` dict, and custom scalar metrics reported by saving values
to the ``episode.custom_metrics`` dict. These custom metrics will be aggregated and
reported as part of training results. For a full example, see
`this example script <https://github.com/ray-project/ray/blob/master/rllib/examples/custom_metrics_and_callbacks.py>`__
and
`these unit tests <https://github.com/ray-project/ray/blob/master/rllib/algorithms/tests/test_callbacks.py>`__.
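
As a minimal sketch (not part of this commit; the class name and metric key below are
illustrative), a callbacks class along these lines could populate ``episode.user_data``
and ``episode.custom_metrics``:

    from ray.rllib.algorithms.callbacks import DefaultCallbacks

    class EpisodeMetricsCallbacks(DefaultCallbacks):
        """Illustrative callbacks that report a custom per-episode metric."""

        def on_episode_start(self, *, worker, base_env, policies, episode, env_index=None, **kwargs):
            # Per-episode scratch space; a fresh dict comes with every new episode object.
            episode.user_data["num_steps"] = 0

        def on_episode_step(self, *, worker, base_env, policies=None, episode, env_index=None, **kwargs):
            episode.user_data["num_steps"] += 1

        def on_episode_end(self, *, worker, base_env, policies, episode, env_index=None, **kwargs):
            # Custom scalar metrics get aggregated (mean/min/max) into the training results.
            episode.custom_metrics["episode_num_steps"] = episode.user_data["num_steps"]

The class would then be registered on a config via ``config.callbacks(EpisodeMetricsCallbacks)``.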

.. tip::
You can create custom logic that can run on each evaluation episode by checking if the
2 changes: 1 addition & 1 deletion rllib/algorithms/algorithm_config.py
@@ -450,8 +450,8 @@ def framework(

def environment(
self,
*,
env: Optional[Union[str, EnvType]] = None,
*,
env_config: Optional[EnvConfigDict] = None,
observation_space: Optional[gym.spaces.Space] = None,
action_space: Optional[gym.spaces.Space] = None,
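
The ``algorithm_config.py`` change above moves the keyword-only marker (``*``) behind ``env``,
so ``env`` can be passed positionally to ``AlgorithmConfig.environment()`` while the remaining
arguments stay keyword-only. A minimal sketch of the resulting call pattern (environment name
chosen for illustration only):

    import ray.rllib.algorithms.dqn as dqn

    # `env` may now be given positionally; everything after it must be a keyword argument.
    config = dqn.DQNConfig().environment("CartPole-v1", env_config={})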
102 changes: 85 additions & 17 deletions rllib/algorithms/callbacks.py
@@ -12,6 +12,7 @@
from ray.rllib.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.annotations import (
override,
OverrideToImplementCustomLogic,
PublicAPI,
)
@@ -51,13 +52,42 @@ def __init__(self, legacy_callbacks_dict: Dict[str, callable] = None):
)
self.legacy_callbacks = legacy_callbacks_dict or {}

@OverrideToImplementCustomLogic
def on_algorithm_init(
self,
*,
algorithm: "Algorithm",
**kwargs,
) -> None:
"""Callback run when a new algorithm instance has finished setup.
This method gets called at the end of Algorithm.setup() after all
the initialization is done, and before actually training starts.
Args:
algorithm: Reference to the Algorithm instance.
kwargs: Forward compatibility placeholder.
"""
pass

@OverrideToImplementCustomLogic
def on_create_policy(self, *, policy_id: PolicyID, policy: Policy) -> None:
"""Callback run whenever a new policy is added to an algorithm.
Args:
policy_id: ID of the newly created policy.
policy: the policy just created.
"""
pass

@OverrideToImplementCustomLogic
def on_sub_environment_created(
self,
*,
worker: "RolloutWorker",
sub_environment: EnvType,
env_context: EnvContext,
env_index: Optional[int] = None,
**kwargs,
) -> None:
"""Callback run when a new sub-environment has been created.
@@ -78,33 +108,30 @@ def on_sub_environment_created(
pass

@OverrideToImplementCustomLogic
def on_algorithm_init(
def before_sub_environment_reset(
self,
*,
algorithm: "Algorithm",
worker: "RolloutWorker",
sub_environment: EnvType,
env_index: int,
**kwargs,
) -> None:
"""Callback run when a new algorithm instance has finished setup.
"""Callback run before a sub-environment is reset.
This method gets called at the end of Algorithm.setup() after all
the initialization is done, and before actually training starts.
This method gets called before every `try_reset()` is called by RLlib
on a sub-environment (usually a gym.Env). This includes the very first (initial)
reset performed on each sub-environment.
Args:
algorithm: Reference to the trainer instance.
worker: Reference to the current rollout worker.
sub_environment: The sub-environment instance that we are about to reset.
This is usually a gym.Env object.
env_index: The index of the sub-environment that is about to be reset
(within the vector of sub-environments of the BaseEnv).
kwargs: Forward compatibility placeholder.
"""
pass

@OverrideToImplementCustomLogic
def on_create_policy(self, *, policy_id: PolicyID, policy: Policy) -> None:
"""Callback run whenever a new policy is added to an algorithm.
Args:
policy_id: ID of the newly created policy.
policy: the policy just created.
"""
pass

@OverrideToImplementCustomLogic
def on_episode_start(
self,
@@ -113,6 +140,7 @@ def on_episode_start(
base_env: BaseEnv,
policies: Dict[PolicyID, Policy],
episode: Union[Episode, EpisodeV2],
env_index: Optional[int] = None,
**kwargs,
) -> None:
"""Callback run on the rollout worker before each episode starts.
@@ -128,6 +156,8 @@ def on_episode_start(
state. You can use the `episode.user_data` dict to store
temporary data, and `episode.custom_metrics` to store custom
metrics for the episode.
env_index: The index of the sub-environment that started the episode
(within the vector of sub-environments of the BaseEnv).
kwargs: Forward compatibility placeholder.
"""

@@ -148,6 +178,7 @@ def on_episode_step(
base_env: BaseEnv,
policies: Optional[Dict[PolicyID, Policy]] = None,
episode: Union[Episode, EpisodeV2],
env_index: Optional[int] = None,
**kwargs,
) -> None:
"""Runs on each episode step.
@@ -164,6 +195,8 @@ def on_episode_step(
state. You can use the `episode.user_data` dict to store
temporary data, and `episode.custom_metrics` to store custom
metrics for the episode.
env_index: The index of the sub-environment that stepped the episode
(within the vector of sub-environments of the BaseEnv).
kwargs: Forward compatibility placeholder.
"""

@@ -180,6 +213,7 @@ def on_episode_end(
base_env: BaseEnv,
policies: Dict[PolicyID, Policy],
episode: Union[Episode, EpisodeV2, Exception],
env_index: Optional[int] = None,
**kwargs,
) -> None:
"""Runs when an episode is done.
@@ -200,6 +234,8 @@ def on_episode_end(
that gets thrown from the environment before the episode finishes.
Users of this callback may then handle these error cases properly
with their custom logic.
env_index: The index of the sub-environment that ended the episode
(within the vector of sub-environments of the BaseEnv).
kwargs: Forward compatibility placeholder.
"""

@@ -403,6 +439,7 @@ def __init__(self):
# Will track the top 10 lines where memory is allocated
tracemalloc.start(10)

@override(DefaultCallbacks)
def on_episode_end(
self,
*,
@@ -469,21 +506,24 @@ def __call__(self, *args, **kwargs):
def on_trainer_init(self, *args, **kwargs):
raise DeprecationWarning

@override(DefaultCallbacks)
def on_algorithm_init(self, *, algorithm: "Algorithm", **kwargs) -> None:
for callback in self._callback_list:
callback.on_algorithm_init(algorithm=algorithm, **kwargs)

@OverrideToImplementCustomLogic
@override(DefaultCallbacks)
def on_create_policy(self, *, policy_id: PolicyID, policy: Policy) -> None:
for callback in self._callback_list:
callback.on_create_policy(policy_id=policy_id, policy=policy)

@override(DefaultCallbacks)
def on_sub_environment_created(
self,
*,
worker: "RolloutWorker",
sub_environment: EnvType,
env_context: EnvContext,
env_index: Optional[int] = None,
**kwargs,
) -> None:
for callback in self._callback_list:
@@ -494,6 +534,24 @@ def on_sub_environment_created(
**kwargs,
)

@override(DefaultCallbacks)
def before_sub_environment_reset(
self,
*,
worker: "RolloutWorker",
sub_environment: EnvType,
env_index: Optional[int] = None,
**kwargs,
) -> None:
for callback in self._callback_list:
callback.before_sub_environment_reset(
worker=worker,
sub_environment=sub_environment,
env_index=env_index,
**kwargs,
)

@override(DefaultCallbacks)
def on_episode_start(
self,
*,
@@ -514,6 +572,7 @@ def on_episode_start(
**kwargs,
)

@override(DefaultCallbacks)
def on_episode_step(
self,
*,
@@ -534,6 +593,7 @@ def on_episode_step(
**kwargs,
)

@override(DefaultCallbacks)
def on_episode_end(
self,
*,
@@ -554,6 +614,7 @@ def on_episode_end(
**kwargs,
)

@override(DefaultCallbacks)
def on_evaluate_start(
self,
*,
@@ -566,6 +627,7 @@ def on_evaluate_start(
**kwargs,
)

@override(DefaultCallbacks)
def on_evaluate_end(
self,
*,
@@ -580,6 +642,7 @@ def on_evaluate_end(
**kwargs,
)

@override(DefaultCallbacks)
def on_postprocess_trajectory(
self,
*,
@@ -604,12 +667,14 @@ def on_postprocess_trajectory(
**kwargs,
)

@override(DefaultCallbacks)
def on_sample_end(
self, *, worker: "RolloutWorker", samples: SampleBatch, **kwargs
) -> None:
for callback in self._callback_list:
callback.on_sample_end(worker=worker, samples=samples, **kwargs)

@override(DefaultCallbacks)
def on_learn_on_batch(
self, *, policy: Policy, train_batch: SampleBatch, result: dict, **kwargs
) -> None:
@@ -618,6 +683,7 @@ def on_learn_on_batch(
policy=policy, train_batch=train_batch, result=result, **kwargs
)

@override(DefaultCallbacks)
def on_train_result(
self, *, algorithm=None, result: dict, trainer=None, **kwargs
) -> None:
@@ -656,6 +722,7 @@ def __init__(
self._rms = _MovingMeanStd()
super().__init__(*args, **kwargs)

@override(DefaultCallbacks)
def on_learn_on_batch(
self,
*,
@@ -687,6 +754,7 @@ def on_learn_on_batch(
train_batch[Postprocessing.VALUE_TARGETS] + states_entropy
)

@override(DefaultCallbacks)
def on_train_result(
self, *, result: dict, algorithm=None, trainer=None, **kwargs
) -> None:
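
A minimal usage sketch for the new hook (class name and counter attribute are illustrative,
not part of this commit):

    from ray.rllib.algorithms.callbacks import DefaultCallbacks

    class CountResetsCallbacks(DefaultCallbacks):
        """Illustrative callbacks that count resets per sub-environment index."""

        def __init__(self):
            super().__init__()
            self.resets_per_index = {}

        def before_sub_environment_reset(self, *, worker, sub_environment, env_index=None, **kwargs):
            # Runs right before RLlib calls `try_reset()` on this sub-environment,
            # including the very first (initial) reset.
            self.resets_per_index[env_index] = self.resets_per_index.get(env_index, 0) + 1

Because the callbacks object lives on each RolloutWorker, any state it accumulates (such as the
counter above) is per worker; the test below retrieves it via ``worker.apply.remote(...)``.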
48 changes: 48 additions & 0 deletions rllib/algorithms/tests/test_callbacks.py
@@ -3,6 +3,7 @@
import ray
from ray.rllib.algorithms.callbacks import DefaultCallbacks, MultiCallbacks
import ray.rllib.algorithms.dqn as dqn
from ray.rllib.examples.env.random_env import RandomEnv
from ray.rllib.utils.test_utils import framework_iterator


@@ -22,6 +23,18 @@ def on_sub_environment_created(
)


class BeforeSubEnvironmentResetCallback(DefaultCallbacks):
def __init__(self):
super().__init__()
self._reset_counter = 0

def before_sub_environment_reset(
self, *, worker, sub_environment, env_index, **kwargs
):
print(f"Sub-env {env_index} is going to be reset.")
self._reset_counter += 1


class TestCallbacks(unittest.TestCase):
@classmethod
def setUpClass(cls):
@@ -102,6 +115,41 @@ def test_on_sub_environment_created_with_remote_envs(self):
self.assertTrue(sum_sub_env_vector_indices[2] == 6)
algo.stop()

def test_before_sub_environment_reset(self):
# 1000 steps sampled (2.5 episodes on each sub-environment) before training
# starts.
config = (
dqn.DQNConfig()
.environment(
RandomEnv,
env_config={
"max_episode_len": 200,
"p_done": 0.0,
},
)
.rollouts(num_envs_per_worker=2, num_rollout_workers=1)
.callbacks(BeforeSubEnvironmentResetCallback)
)

for _ in framework_iterator(config, frameworks=("tf", "torch")):
algo = config.build()
algo.train()
# Two sub-environments share 1000 steps in the first training iteration
# (min_sample_timesteps_per_iteration = 1000).
# -> 1000 / 2 [sub-envs] = 500 [per sub-env]
# -> 1 episode = 200 timesteps
# -> 2.5 episodes per sub-env
# -> 3 resets [per sub-env] = 6 resets total
self.assertTrue(
6
== ray.get(
algo.workers.remote_workers()[0].apply.remote(
lambda w: w.callbacks._reset_counter
)
)
)
algo.stop()


if __name__ == "__main__":
import pytest
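
Since the test module imports ``MultiCallbacks``, here is a sketch of combining several
callbacks classes; it assumes ``BeforeSubEnvironmentResetCallback`` from the test file above
is in scope, and ``NoopCallbacks`` is a made-up placeholder:

    import ray.rllib.algorithms.dqn as dqn
    from ray.rllib.algorithms.callbacks import DefaultCallbacks, MultiCallbacks

    class NoopCallbacks(DefaultCallbacks):
        """Placeholder second callbacks class, purely for illustration."""
        pass

    # Each hook (e.g. `before_sub_environment_reset`) is forwarded to every listed class in order.
    config = (
        dqn.DQNConfig()
        .environment("CartPole-v1")
        .rollouts(num_envs_per_worker=2, num_rollout_workers=1)
        .callbacks(MultiCallbacks([BeforeSubEnvironmentResetCallback, NoopCallbacks]))
    )
    algo = config.build()
    algo.train()
    algo.stop()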