diff --git a/doc/source/rllib/rllib-training.rst b/doc/source/rllib/rllib-training.rst index d3b794112956..2a22b9eecff8 100644 --- a/doc/source/rllib/rllib-training.rst +++ b/doc/source/rllib/rllib-training.rst @@ -1013,9 +1013,21 @@ Ray actors provide high levels of performance, so in more complex cases they can Callbacks and Custom Metrics ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -You can provide callbacks to be called at points during policy evaluation. These callbacks have access to state for the current `episode `__. Certain callbacks such as ``on_postprocess_trajectory``, ``on_sample_end``, and ``on_train_result`` are also places where custom postprocessing can be applied to intermediate data or results. - -User-defined state can be stored for the `episode `__ in the ``episode.user_data`` dict, and custom scalar metrics reported by saving values to the ``episode.custom_metrics`` dict. These custom metrics will be aggregated and reported as part of training results. For a full example, see `custom_metrics_and_callbacks.py `__. +You can provide callbacks to be called at points during policy evaluation. +These callbacks have access to state for the current +`episode `__. +Certain callbacks such as ``on_postprocess_trajectory``, ``on_sample_end``, +and ``on_train_result`` are also places where custom postprocessing can be applied to +intermediate data or results. + +User-defined state can be stored for the +`episode `__ +in the ``episode.user_data`` dict, and custom scalar metrics reported by saving values +to the ``episode.custom_metrics`` dict. These custom metrics will be aggregated and +reported as part of training results. For a full example, take a look at +`this example script here `__ +and +`these unit test cases here `__. .. tip:: You can create custom logic that can run on each evaluation episode by checking if the diff --git a/rllib/algorithms/algorithm_config.py b/rllib/algorithms/algorithm_config.py index fd611cd24870..8b382bb2ad27 100644 --- a/rllib/algorithms/algorithm_config.py +++ b/rllib/algorithms/algorithm_config.py @@ -450,8 +450,8 @@ def framework( def environment( self, - *, env: Optional[Union[str, EnvType]] = None, + *, env_config: Optional[EnvConfigDict] = None, observation_space: Optional[gym.spaces.Space] = None, action_space: Optional[gym.spaces.Space] = None, diff --git a/rllib/algorithms/callbacks.py b/rllib/algorithms/callbacks.py index c2c7ba2fb47c..7038d2e6b8fa 100644 --- a/rllib/algorithms/callbacks.py +++ b/rllib/algorithms/callbacks.py @@ -12,6 +12,7 @@ from ray.rllib.policy import Policy from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.utils.annotations import ( + override, OverrideToImplementCustomLogic, PublicAPI, ) @@ -51,6 +52,34 @@ def __init__(self, legacy_callbacks_dict: Dict[str, callable] = None): ) self.legacy_callbacks = legacy_callbacks_dict or {} + @OverrideToImplementCustomLogic + def on_algorithm_init( + self, + *, + algorithm: "Algorithm", + **kwargs, + ) -> None: + """Callback run when a new algorithm instance has finished setup. + + This method gets called at the end of Algorithm.setup() after all + the initialization is done, and before actually training starts. + + Args: + algorithm: Reference to the trainer instance. + kwargs: Forward compatibility placeholder. + """ + pass + + @OverrideToImplementCustomLogic + def on_create_policy(self, *, policy_id: PolicyID, policy: Policy) -> None: + """Callback run whenever a new policy is added to an algorithm. + + Args: + policy_id: ID of the newly created policy. 
+ policy: the policy just created. + """ + pass + @OverrideToImplementCustomLogic def on_sub_environment_created( self, @@ -58,6 +87,7 @@ def on_sub_environment_created( worker: "RolloutWorker", sub_environment: EnvType, env_context: EnvContext, + env_index: Optional[int] = None, **kwargs, ) -> None: """Callback run when a new sub-environment has been created. @@ -78,33 +108,30 @@ def on_sub_environment_created( pass @OverrideToImplementCustomLogic - def on_algorithm_init( + def before_sub_environment_reset( self, *, - algorithm: "Algorithm", + worker: "RolloutWorker", + sub_environment: EnvType, + env_index: int, **kwargs, ) -> None: - """Callback run when a new algorithm instance has finished setup. + """Callback run before a sub-environment is reset. - This method gets called at the end of Algorithm.setup() after all - the initialization is done, and before actually training starts. + This method gets called before every `try_reset()` is called by RLlib + on a sub-environment (usually a gym.Env). This includes the very first (initial) + reset performed on each sub-environment. Args: - algorithm: Reference to the trainer instance. + worker: Reference to the current rollout worker. + sub_environment: The sub-environment instance that we are about to reset. + This is usually a gym.Env object. + env_index: The index of the sub-environment that is about to be reset + (within the vector of sub-environments of the BaseEnv). kwargs: Forward compatibility placeholder. """ pass - @OverrideToImplementCustomLogic - def on_create_policy(self, *, policy_id: PolicyID, policy: Policy) -> None: - """Callback run whenever a new policy is added to an algorithm. - - Args: - policy_id: ID of the newly created policy. - policy: the policy just created. - """ - pass - @OverrideToImplementCustomLogic def on_episode_start( self, @@ -113,6 +140,7 @@ def on_episode_start( base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode: Union[Episode, EpisodeV2], + env_index: Optional[int] = None, **kwargs, ) -> None: """Callback run on the rollout worker before each episode starts. @@ -128,6 +156,8 @@ def on_episode_start( state. You can use the `episode.user_data` dict to store temporary data, and `episode.custom_metrics` to store custom metrics for the episode. + env_index: The index of the sub-environment that started the episode + (within the vector of sub-environments of the BaseEnv). kwargs: Forward compatibility placeholder. """ @@ -148,6 +178,7 @@ def on_episode_step( base_env: BaseEnv, policies: Optional[Dict[PolicyID, Policy]] = None, episode: Union[Episode, EpisodeV2], + env_index: Optional[int] = None, **kwargs, ) -> None: """Runs on each episode step. @@ -164,6 +195,8 @@ def on_episode_step( state. You can use the `episode.user_data` dict to store temporary data, and `episode.custom_metrics` to store custom metrics for the episode. + env_index: The index of the sub-environment that stepped the episode + (within the vector of sub-environments of the BaseEnv). kwargs: Forward compatibility placeholder. """ @@ -180,6 +213,7 @@ def on_episode_end( base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode: Union[Episode, EpisodeV2, Exception], + env_index: Optional[int] = None, **kwargs, ) -> None: """Runs when an episode is done. @@ -200,6 +234,8 @@ def on_episode_end( that gets thrown from the environment before the episode finishes. Users of this callback may then handle these error cases properly with their custom logics. 
+ env_index: The index of the sub-environment that ended the episode + (within the vector of sub-environments of the BaseEnv). kwargs: Forward compatibility placeholder. """ @@ -403,6 +439,7 @@ def __init__(self): # Will track the top 10 lines where memory is allocated tracemalloc.start(10) + @override(DefaultCallbacks) def on_episode_end( self, *, @@ -469,21 +506,24 @@ def __call__(self, *args, **kwargs): def on_trainer_init(self, *args, **kwargs): raise DeprecationWarning + @override(DefaultCallbacks) def on_algorithm_init(self, *, algorithm: "Algorithm", **kwargs) -> None: for callback in self._callback_list: callback.on_algorithm_init(algorithm=algorithm, **kwargs) - @OverrideToImplementCustomLogic + @override(DefaultCallbacks) def on_create_policy(self, *, policy_id: PolicyID, policy: Policy) -> None: for callback in self._callback_list: callback.on_create_policy(policy_id=policy_id, policy=policy) + @override(DefaultCallbacks) def on_sub_environment_created( self, *, worker: "RolloutWorker", sub_environment: EnvType, env_context: EnvContext, + env_index: Optional[int] = None, **kwargs, ) -> None: for callback in self._callback_list: @@ -494,6 +534,24 @@ def on_sub_environment_created( **kwargs, ) + @override(DefaultCallbacks) + def before_sub_environment_reset( + self, + *, + worker: "RolloutWorker", + sub_environment: EnvType, + env_index: Optional[int] = None, + **kwargs, + ) -> None: + for callback in self._callback_list: + callback.before_sub_environment_reset( + worker=worker, + sub_environment=sub_environment, + env_index=env_index, + **kwargs, + ) + + @override(DefaultCallbacks) def on_episode_start( self, *, @@ -514,6 +572,7 @@ def on_episode_start( **kwargs, ) + @override(DefaultCallbacks) def on_episode_step( self, *, @@ -534,6 +593,7 @@ def on_episode_step( **kwargs, ) + @override(DefaultCallbacks) def on_episode_end( self, *, @@ -554,6 +614,7 @@ def on_episode_end( **kwargs, ) + @override(DefaultCallbacks) def on_evaluate_start( self, *, @@ -566,6 +627,7 @@ def on_evaluate_start( **kwargs, ) + @override(DefaultCallbacks) def on_evaluate_end( self, *, @@ -580,6 +642,7 @@ def on_evaluate_end( **kwargs, ) + @override(DefaultCallbacks) def on_postprocess_trajectory( self, *, @@ -604,12 +667,14 @@ def on_postprocess_trajectory( **kwargs, ) + @override(DefaultCallbacks) def on_sample_end( self, *, worker: "RolloutWorker", samples: SampleBatch, **kwargs ) -> None: for callback in self._callback_list: callback.on_sample_end(worker=worker, samples=samples, **kwargs) + @override(DefaultCallbacks) def on_learn_on_batch( self, *, policy: Policy, train_batch: SampleBatch, result: dict, **kwargs ) -> None: @@ -618,6 +683,7 @@ def on_learn_on_batch( policy=policy, train_batch=train_batch, result=result, **kwargs ) + @override(DefaultCallbacks) def on_train_result( self, *, algorithm=None, result: dict, trainer=None, **kwargs ) -> None: @@ -656,6 +722,7 @@ def __init__( self._rms = _MovingMeanStd() super().__init__(*args, **kwargs) + @override(DefaultCallbacks) def on_learn_on_batch( self, *, @@ -687,6 +754,7 @@ def on_learn_on_batch( train_batch[Postprocessing.VALUE_TARGETS] + states_entropy ) + @override(DefaultCallbacks) def on_train_result( self, *, result: dict, algorithm=None, trainer=None, **kwargs ) -> None: diff --git a/rllib/algorithms/tests/test_callbacks.py b/rllib/algorithms/tests/test_callbacks.py index ddbbb90bb95d..5d9ec6ac20f8 100644 --- a/rllib/algorithms/tests/test_callbacks.py +++ b/rllib/algorithms/tests/test_callbacks.py @@ -3,6 +3,7 @@ import ray from 
ray.rllib.algorithms.callbacks import DefaultCallbacks, MultiCallbacks import ray.rllib.algorithms.dqn as dqn +from ray.rllib.examples.env.random_env import RandomEnv from ray.rllib.utils.test_utils import framework_iterator @@ -22,6 +23,18 @@ def on_sub_environment_created( ) +class BeforeSubEnvironmentResetCallback(DefaultCallbacks): + def __init__(self): + super().__init__() + self._reset_counter = 0 + + def before_sub_environment_reset( + self, *, worker, sub_environment, env_index, **kwargs + ): + print(f"Sub-env {env_index} is going to be reset.") + self._reset_counter += 1 + + class TestCallbacks(unittest.TestCase): @classmethod def setUpClass(cls): @@ -102,6 +115,41 @@ def test_on_sub_environment_created_with_remote_envs(self): self.assertTrue(sum_sub_env_vector_indices[2] == 6) algo.stop() + def test_before_sub_environment_reset(self): + # 1000 steps sampled (2.5 episodes on each sub-environment) before training + # starts. + config = ( + dqn.DQNConfig() + .environment( + RandomEnv, + env_config={ + "max_episode_len": 200, + "p_done": 0.0, + }, + ) + .rollouts(num_envs_per_worker=2, num_rollout_workers=1) + .callbacks(BeforeSubEnvironmentResetCallback) + ) + + for _ in framework_iterator(config, frameworks=("tf", "torch")): + algo = config.build() + algo.train() + # Two sub-environments share 1000 steps in the first training iteration + # (min_sample_timesteps_per_iteration = 1000). + # -> 1000 / 2 [sub-envs] = 500 [per sub-env] + # -> 1 episode = 200 timesteps + # -> 2.5 episodes per sub-env + # -> 3 resets [per sub-env] = 6 resets total + self.assertTrue( + 6 + == ray.get( + algo.workers.remote_workers()[0].apply.remote( + lambda w: w.callbacks._reset_counter + ) + ) + ) + algo.stop() + if __name__ == "__main__": import pytest diff --git a/rllib/evaluation/env_runner_v2.py b/rllib/evaluation/env_runner_v2.py index 5acb56912424..1490ea652e7c 100644 --- a/rllib/evaluation/env_runner_v2.py +++ b/rllib/evaluation/env_runner_v2.py @@ -393,6 +393,17 @@ def run(self) -> Iterator[SampleBatchType]: Object containing state, action, reward, terminal condition, and other fields as dictated by `policy`. """ + # Before the very first poll (this will reset all vector sub-environments): + # Call custom `before_sub_environment_reset` callbacks for all sub-environments. + for env_id, sub_env in self._base_env.get_sub_environments( + as_dict=True + ).items(): + self._callbacks.before_sub_environment_reset( + worker=self._worker, + sub_environment=sub_env, + env_index=env_id, + ) + while True: self._perf_stats.incr("iters", 1) @@ -773,6 +784,15 @@ def _handle_done_episode( # Basically carry RNN and other buffered state to the # next episode from the same env. else: + # Call custom `before_sub_environment_reset` callback. + self._callbacks.before_sub_environment_reset( + worker=self._worker, + sub_environment=self._base_env.get_sub_environments(as_dict=True)[ + env_id + ], + env_index=env_id, + ) + # TODO(jungong) : This will allow a single faulty env to # take out the entire RolloutWorker indefinitely. Revisit. while True: diff --git a/rllib/evaluation/sampler.py b/rllib/evaluation/sampler.py index 06650bc3fe03..f9bd09bc8dd0 100644 --- a/rllib/evaluation/sampler.py +++ b/rllib/evaluation/sampler.py @@ -664,6 +664,15 @@ def new_episode(env_id): active_episodes: Dict[EnvID, Episode] = _NewEpisodeDefaultDict(new_episode) + # Before the very first poll (this will reset all vector sub-environments): + # Call custom `before_sub_environment_reset` callbacks for all sub-environments. 
+ for env_id, sub_env in base_env.get_sub_environments(as_dict=True).items(): + callbacks.before_sub_environment_reset( + worker=worker, + sub_environment=sub_env, + env_index=env_id, + ) + while True: perf_stats.incr("iters", 1) @@ -1080,6 +1089,16 @@ def _process_observations( } else: del active_episodes[env_id] + + # Call custom `before_sub_environment_reset` callback. + sub_envs = base_env.get_sub_environments(as_dict=True) + if env_id in sub_envs: + callbacks.before_sub_environment_reset( + worker=worker, + sub_environment=sub_envs[env_id], + env_index=env_id, + ) + # TODO(jungong) : This will allow a single faulty env to # take out the entire RolloutWorker indefinitely. Revisit. while True:
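
A minimal sketch (not part of the diff) of how the pieces above fit together, assuming the RLlib callback API at this revision. The class name ``MyCallbacks``, the ``episode_num_steps`` metric, and the ``CartPole-v1``/DQN setup are illustrative choices only, not anything prescribed by the change:

```python
import ray
import ray.rllib.algorithms.dqn as dqn
from ray.rllib.algorithms.callbacks import DefaultCallbacks


class MyCallbacks(DefaultCallbacks):
    def before_sub_environment_reset(
        self, *, worker, sub_environment, env_index, **kwargs
    ):
        # Runs before every `try_reset()` on a sub-environment, including the
        # very first reset of each sub-environment in the vectorized BaseEnv.
        print(f"sub-env {env_index} is about to be reset")

    def on_episode_start(
        self, *, worker, base_env, policies, episode, env_index=None, **kwargs
    ):
        # Per-episode scratch state lives in `episode.user_data`.
        episode.user_data["num_steps"] = 0

    def on_episode_step(
        self, *, worker, base_env, policies=None, episode, env_index=None, **kwargs
    ):
        episode.user_data["num_steps"] += 1

    def on_episode_end(
        self, *, worker, base_env, policies, episode, env_index=None, **kwargs
    ):
        # Scalars written here are aggregated (mean/min/max) into the
        # `custom_metrics` section of the training results.
        episode.custom_metrics["episode_num_steps"] = episode.user_data["num_steps"]


if __name__ == "__main__":
    ray.init()
    config = (
        dqn.DQNConfig()
        .environment("CartPole-v1")
        .rollouts(num_rollout_workers=1, num_envs_per_worker=2)
        .callbacks(MyCallbacks)
    )
    algo = config.build()
    result = algo.train()
    # e.g. {"episode_num_steps_mean": ..., "episode_num_steps_min": ..., ...}
    print(result["custom_metrics"])
    algo.stop()
```

With ``num_envs_per_worker=2``, ``env_index`` is 0 or 1 and identifies which sub-environment of the worker's vectorized ``BaseEnv`` triggered the hook. ``before_sub_environment_reset`` fires once per sub-environment before the initial reset and again before every subsequent reset, which is what the new unit test counts (3 resets per sub-env, 6 in total, over 1000 sampled timesteps at 200 steps per episode).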
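
Several callback classes can also be combined. Since the diff routes the new hook (and adds ``@override(DefaultCallbacks)`` checks) through ``MultiCallbacks``, a composed setup picks it up as well. Another hedged sketch, with illustrative class names:

```python
import ray.rllib.algorithms.dqn as dqn
from ray.rllib.algorithms.callbacks import DefaultCallbacks, MultiCallbacks


class LogResetsCallbacks(DefaultCallbacks):
    def before_sub_environment_reset(
        self, *, worker, sub_environment, env_index, **kwargs
    ):
        print(f"sub-env {env_index} about to be reset")


class CountEpisodesCallbacks(DefaultCallbacks):
    def __init__(self):
        super().__init__()
        self.num_episodes = 0

    def on_episode_end(
        self, *, worker, base_env, policies, episode, env_index=None, **kwargs
    ):
        self.num_episodes += 1


# MultiCallbacks acts as a callbacks factory: it instantiates every class in
# the list and forwards each hook, including `before_sub_environment_reset`,
# to all of them in order.
config = (
    dqn.DQNConfig()
    .environment("CartPole-v1")
    .callbacks(MultiCallbacks([LogResetsCallbacks, CountEpisodesCallbacks]))
)
algo = config.build()
algo.train()
algo.stop()
```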