
complete_episodes breaks custom callback example [Bug] #22683

Closed
simonsays1980 opened this issue Feb 27, 2022 · 1 comment · Fixed by #22900
Labels: bug, triage

@simonsays1980 (Collaborator)

Search before asking

  • I searched the issues and found no similar issues.

Ray Component

RLlib

What happened + What you expected to happen

What happened

I used the custom_metrics_and_callbacks.py example as a template to build my own callbacks for custom metrics. I am running with batch_mode="complete_episodes", which resulted in an error that was neither easy to trace nor easy to resolve.
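
For context, the only change relative to the stock example config is the batch mode. A minimal sketch of the triggering setting (illustrative only; MyCallbacks refers to the callbacks class in the reproduction script below):

# The stock example runs with the default "truncate_episodes".
# Switching the batch mode is what surfaces the assertion error
# inside `on_episode_end()`.
config = {
    "env": "CartPole-v0",
    "callbacks": MyCallbacks,
    "batch_mode": "complete_episodes",  # default: "truncate_episodes"
}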

What I expected to happen

That the example can be used as a guideline for building custom metrics without running into errors that are hard to trace.

Versions / Dependencies

ray 1.10.0
Python 3.9.0
Fedora Linux 35

Reproduction script

"""Example of using RLlib's debug callbacks.

Here we use callbacks to track the average CartPole pole angle magnitude as a
custom metric.
"""

from typing import Dict
import argparse
import numpy as np
import os

import ray
from ray import tune
from ray.rllib.agents.callbacks import DefaultCallbacks
from ray.rllib.env import BaseEnv
from ray.rllib.evaluation import Episode, RolloutWorker
from ray.rllib.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch

parser = argparse.ArgumentParser()
parser.add_argument(
    "--framework",
    choices=["tf", "tf2", "tfe", "torch"],
    default="tf",
    help="The DL framework specifier.",
)
parser.add_argument("--stop-iters", type=int, default=2000)


class MyCallbacks(DefaultCallbacks):
    def on_episode_start(
        self,
        *,
        worker: RolloutWorker,
        base_env: BaseEnv,
        policies: Dict[str, Policy],
        episode: Episode,
        env_index: int,
        **kwargs
    ):
        # Make sure this episode has just been started (only initial obs
        # logged so far).
        assert episode.length == 0, (
            "ERROR: `on_episode_start()` callback should be called right "
            "after env reset!"
        )
        print("episode {} (env-idx={}) started.".format(episode.episode_id, env_index))
        episode.user_data["pole_angles"] = []
        episode.hist_data["pole_angles"] = []

    def on_episode_step(
        self,
        *,
        worker: RolloutWorker,
        base_env: BaseEnv,
        policies: Dict[str, Policy],
        episode: Episode,
        env_index: int,
        **kwargs
    ):
        # Make sure this episode is ongoing.
        assert episode.length > 0, (
            "ERROR: `on_episode_step()` callback should not be called right "
            "after env reset!"
        )
        pole_angle = abs(episode.last_observation_for()[2])
        raw_angle = abs(episode.last_raw_obs_for()[2])
        assert pole_angle == raw_angle
        episode.user_data["pole_angles"].append(pole_angle)

    def on_episode_end(
        self,
        *,
        worker: RolloutWorker,
        base_env: BaseEnv,
        policies: Dict[str, Policy],
        episode: Episode,
        env_index: int,
        **kwargs
    ):
        # Check if there are multiple episodes in a batch, i.e.
        # "batch_mode":"truncate_episodes".
        if worker.sampler.sample_collector.multiple_episodes_in_batch:
            # Make sure this episode is really done.
            assert episode.batch_builder.policy_collectors["default_policy"].batches[
                -1
            ]["dones"][-1], (
                "ERROR: `on_episode_end()` should only be called "
                "after episode is done!"
            )
        pole_angle = np.mean(episode.user_data["pole_angles"])
        print(
            "episode {} (env-idx={}) ended with length {} and pole "
            "angles {}".format(
                episode.episode_id, env_index, episode.length, pole_angle
            )
        )
        episode.custom_metrics["pole_angle"] = pole_angle
        episode.hist_data["pole_angles"] = episode.user_data["pole_angles"]

    def on_sample_end(self, *, worker: RolloutWorker, samples: SampleBatch, **kwargs):
        print("returned sample batch of size {}".format(samples.count))

    def on_train_result(self, *, trainer, result: dict, **kwargs):
        print(
            "trainer.train() result: {} -> {} episodes".format(
                trainer, result["episodes_this_iter"]
            )
        )
        # you can mutate the result dict to add new fields to return
        result["callback_ok"] = True

    def on_learn_on_batch(
        self, *, policy: Policy, train_batch: SampleBatch, result: dict, **kwargs
    ) -> None:
        result["sum_actions_in_train_batch"] = np.sum(train_batch["actions"])
        print(
            "policy.learn_on_batch() result: {} -> sum actions: {}".format(
                policy, result["sum_actions_in_train_batch"]
            )
        )

    def on_postprocess_trajectory(
        self,
        *,
        worker: RolloutWorker,
        episode: Episode,
        agent_id: str,
        policy_id: str,
        policies: Dict[str, Policy],
        postprocessed_batch: SampleBatch,
        original_batches: Dict[str, SampleBatch],
        **kwargs
    ):
        print("postprocessed {} steps".format(postprocessed_batch.count))
        if "num_batches" not in episode.custom_metrics:
            episode.custom_metrics["num_batches"] = 0
        episode.custom_metrics["num_batches"] += 1


if __name__ == "__main__":
    args = parser.parse_args()

    ray.init(local_mode=True)
    trials = tune.run(
        "PG",
        stop={
            "training_iteration": args.stop_iters,
        },
        config={
            "env": "CartPole-v0",
            "num_envs_per_worker": 2,
            "callbacks": MyCallbacks,
            "framework": args.framework,
            # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
            "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
            "batch_mode": "complete_episodes",
        },
    ).trials

    # Verify episode-related custom metrics are there.
    custom_metrics = trials[0].last_result["custom_metrics"]
    print(custom_metrics)
    assert "pole_angle_mean" in custom_metrics
    assert "pole_angle_min" in custom_metrics
    assert "pole_angle_max" in custom_metrics
    assert "num_batches_mean" in custom_metrics
    assert "callback_ok" in trials[0].last_result

    # Verify `on_learn_on_batch` custom metrics are there (per policy).
    if args.framework == "torch":
        info_custom_metrics = custom_metrics["default_policy"]
        print(info_custom_metrics)
        assert "sum_actions_in_train_batch" in info_custom_metrics

Anything else

This is not an actual bug, and it is neither a support request nor a feature request. It is a note towards more readable code, better traceability of errors, and a guideline that reliably leads end-users to their goal.

Are you willing to submit a PR?

  • Yes I am willing to submit a PR!
simonsays1980 added the bug and triage labels on Feb 27, 2022
@simonsays1980 (Collaborator, Author)

@gjoliver @sven1977 #22684 should close this issue, and #22900 is an improvement containing one of the two suggestions made by @sven1977 in #22684.

@sven1977 The second suggestion has not made it into the PR, as @gjoliver saw possible compatibility problems with multiple_episodes_in_batch.
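
For anyone hitting the same assertion in the meantime, the guard used in the reproduction script above can serve as a workaround. This is only a sketch of that pattern, not necessarily the exact change merged in #22900:

# Only check the batch builder's "dones" flag when multiple episodes can
# share a batch (batch_mode="truncate_episodes"); with
# "complete_episodes" the check is skipped.
def on_episode_end(self, *, worker, base_env, policies, episode, env_index, **kwargs):
    if worker.sampler.sample_collector.multiple_episodes_in_batch:
        last_batch = episode.batch_builder.policy_collectors["default_policy"].batches[-1]
        assert last_batch["dones"][-1], (
            "ERROR: `on_episode_end()` should only be called after episode is done!"
        )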
