[RLlib] Enhance env-rendering callback. (#45682)
sven1977 authored Jun 3, 2024
1 parent 6147b15 commit e6e21ac
Showing 1 changed file with 34 additions and 24 deletions.
58 changes: 34 additions & 24 deletions rllib/examples/envs/env_rendering_and_recording.py
@@ -59,6 +59,7 @@
 """
 import gymnasium as gym
 import numpy as np
+from typing import Optional, Sequence

 from ray.rllib.algorithms.callbacks import DefaultCallbacks
 from ray.rllib.env.wrappers.atari_wrappers import wrap_atari_for_new_api_stack
@@ -85,8 +86,10 @@ class EnvRenderCallback(DefaultCallbacks):
     and temporarily store it in the Episode object.
     """

-    def __init__(self):
+    def __init__(self, env_runner_indices: Optional[Sequence[int]] = None):
         super().__init__()
+        # Only render and record on certain EnvRunner indices?
+        self.env_runner_indices = env_runner_indices
         # Per sample round (on this EnvRunner), we want to only log the best- and
         # worst performing episode's videos in the custom metrics. Otherwise, too much
         # data would be sent to WandB.
@@ -108,6 +111,12 @@ def on_episode_step(
         Note that this would work with MultiAgentEpisodes as well.
         """
+        if (
+            self.env_runner_indices is not None
+            and env_runner.worker_index not in self.env_runner_indices
+        ):
+            return
+
         # If we have a vector env, only render the sub-env at index 0.
         if isinstance(env.unwrapped, gym.vector.VectorEnv):
             image = env.envs[0].render()
@@ -184,30 +193,31 @@ def on_sample_end(
     ) -> None:
         """Logs the best and worst video to this EnvRunner's MetricsLogger."""
         # Best video.
-        metrics_logger.log_value(
-            "episode_videos_best",
-            self.best_episode_and_return[0],
-            # Do not reduce the videos (across the various parallel EnvRunners). This
-            # would not make sense (mean over the pixels?). Instead, we want to log all
-            # best videos of all EnvRunners per iteration.
-            reduce=None,
-            # B/c we do NOT reduce over the video data (mean/min/max), we need to make
-            # sure the list of videos in our MetricsLogger does not grow infinitely and
-            # gets cleared after each `reduce()` operation, meaning every time, the
-            # EnvRunner is asked to send its logged metrics.
-            clear_on_reduce=True,
-        )
+        if self.best_episode_and_return[0] is not None:
+            metrics_logger.log_value(
+                "episode_videos_best",
+                self.best_episode_and_return[0],
+                # Do not reduce the videos (across the various parallel EnvRunners).
+                # This would not make sense (mean over the pixels?). Instead, we want to
+                # log all best videos of all EnvRunners per iteration.
+                reduce=None,
+                # B/c we do NOT reduce over the video data (mean/min/max), we need to
+                # make sure the list of videos in our MetricsLogger does not grow
+                # infinitely and gets cleared after each `reduce()` operation, meaning
+                # every time, the EnvRunner is asked to send its logged metrics.
+                clear_on_reduce=True,
+            )
+            self.best_episode_and_return = (None, float("-inf"))
         # Worst video.
-        metrics_logger.log_value(
-            "episode_videos_worst",
-            self.worst_episode_and_return[0],
-            # Same logging options as above.
-            reduce=None,
-            clear_on_reduce=True,
-        )
-        # Reset our best/worst placeholders.
-        self.best_episode_and_return = (None, float("-inf"))
-        self.worst_episode_and_return = (None, float("inf"))
+        if self.worst_episode_and_return[0] is not None:
+            metrics_logger.log_value(
+                "episode_videos_worst",
+                self.worst_episode_and_return[0],
+                # Same logging options as above.
+                reduce=None,
+                clear_on_reduce=True,
+            )
+            self.worst_episode_and_return = (None, float("inf"))


 if __name__ == "__main__":
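For context on the `reduce=None` / `clear_on_reduce=True` combination used in `on_sample_end` above, here is a minimal standalone sketch of the intended semantics. It is an illustration only; the import path, the no-argument constructor, and the exact return shape of `reduce()` are assumptions, not something this commit defines:

# Minimal sketch of the logging semantics (assumptions: import path,
# no-arg MetricsLogger constructor, and the return shape of reduce()).
# With reduce=None, logged values are collected into a list rather than
# aggregated (a mean/min/max over video pixels would be meaningless); with
# clear_on_reduce=True, that list is emptied on every reduce() call, i.e.
# every time the EnvRunner ships its metrics, so it cannot grow unboundedly.
import numpy as np
from ray.rllib.utils.metrics.metrics_logger import MetricsLogger

logger = MetricsLogger()
fake_video = np.zeros((1, 16, 64, 64, 3), dtype=np.uint8)  # stand-in video tensor
logger.log_value("episode_videos_best", fake_video, reduce=None, clear_on_reduce=True)

results = logger.reduce()  # assumed: {"episode_videos_best": [fake_video]}
# After reduce(), the internal list is cleared; the next sample round starts empty.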

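And a sketch of how the enhanced callback might be wired into an algorithm config, using the new `env_runner_indices` constructor argument. The `functools.partial` pattern, the PPO/CartPole setup, and the `render_mode` plumbing via `env_config` are illustrative assumptions, not part of this commit:

# Usage sketch (assumptions: .callbacks() accepts any callable such as a
# functools.partial, and env_config kwargs are forwarded to gym.make() so
# that render() returns RGB frames).
from functools import partial

from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.examples.envs.env_rendering_and_recording import EnvRenderCallback

config = (
    PPOConfig()
    .environment("CartPole-v1", env_config={"render_mode": "rgb_array"})
    .env_runners(num_env_runners=2)
    # Pre-bind the new constructor argument so only EnvRunner index 1 renders
    # and records; all other EnvRunners skip the extra work and return early
    # from on_episode_step().
    .callbacks(partial(EnvRenderCallback, env_runner_indices=[1]))
)
algo = config.build()
print(algo.train())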
