[RLlib] Add 2D box example for PPO RL Modules (ray-project#33840)
Signed-off-by: Kourosh Hakhamaneshi <[email protected]>
ArturNiederfahrenhorst authored and vitsai committed Apr 17, 2023
1 parent 81b678b commit a659e23
Showing 2 changed files with 128 additions and 0 deletions.
8 changes: 8 additions & 0 deletions rllib/BUILD
@@ -3302,6 +3302,14 @@ py_test(
    args = ["--stop-iters=2"]
)

py_test(
    name = "examples/env/greyscale_env",
    tags = ["team:rllib", "examples", "no_main"],
    size = "medium",
    srcs = ["examples/env/greyscale_env.py"],
    args = ["--stop-iters=1", "--as-test", "--framework=torch"]
)

py_test(
name = "examples/export/cartpole_dqn_export",
main = "examples/export/cartpole_dqn_export.py",
120 changes: 120 additions & 0 deletions rllib/examples/env/greyscale_env.py
@@ -0,0 +1,120 @@
"""
Example of interfacing with an environment that produces 2D observations.
This example shows how turning 2D observations with shape (A, B) into a 3D
observations with shape (C, D, 1) can enable usage of RLlib's default models.
RLlib's default Catalog class does not provide default models for 2D observation
spaces, but it does so for 3D observations.
Therefore, one can either write a custom model or transform the 2D observations into 3D
observations. This enables RLlib to use one of the default CNN filters, even though the
original observation space of the environment does not fit them.
This simple example should reach rewards of 50 within 150k timesteps.
"""

import argparse

from numpy import float32
from pettingzoo.butterfly import pistonball_v6
from supersuit import (
    normalize_obs_v0,
    dtype_v0,
    color_reduction_v0,
    reshape_v0,
    resize_v1,
)

from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.env import PettingZooEnv
from ray.tune.registry import register_env
from ray import tune
from ray import air


parser = argparse.ArgumentParser()
parser.add_argument(
    "--framework",
    choices=["tf2", "torch"],
    default="torch",
    help="The DL framework specifier.",
)
parser.add_argument(
    "--as-test",
    action="store_true",
    help="Whether this script should be run as a compilation test.",
)
parser.add_argument(
    "--stop-iters", type=int, default=150, help="Number of iterations to train."
)
parser.add_argument(
    "--stop-timesteps", type=int, default=1000000, help="Number of timesteps to train."
)
parser.add_argument(
    "--stop-reward", type=float, default=50, help="Reward at which we stop training."
)

args = parser.parse_args()


# The shape we down-sample and reshape the greyscale pistonball images to.
# Other shapes supported by RLlib's default models can be chosen here.
TRANSFORMED_OBS_SPACE = (42, 42, 1)
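# Note on the size choice (an assumption based on RLlib's shipped conv
# defaults): default CNN filter configs exist only for a few square input
# sizes, e.g. 84x84 and 42x42, which makes (42, 42, 1) a convenient target.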


def env_creator(config):
    env = pistonball_v6.env(n_pistons=5)
    env = dtype_v0(env, dtype=float32)
    # This gives us greyscale images of the red color channel.
    env = color_reduction_v0(env, mode="R")
    env = normalize_obs_v0(env)
    # This resizes the images to the pixel size expected by the default
    # CNN filters.
    env = resize_v1(
        env, x_size=TRANSFORMED_OBS_SPACE[0], y_size=TRANSFORMED_OBS_SPACE[1]
    )
    # This gives us 3D images for which RLlib has default filters.
    env = reshape_v0(env, shape=TRANSFORMED_OBS_SPACE)
    return env
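
# A quick sanity check of the transformed space (a sketch; the agent id is
# taken from the PettingZoo API's `possible_agents`):
#
#     check_env = env_creator({})
#     check_env.reset()
#     agent = check_env.possible_agents[0]
#     print(check_env.observation_space(agent).shape)  # -> (42, 42, 1)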


# Register env
register_env("pistonball", lambda config: PettingZooEnv(env_creator(config)))
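# PettingZooEnv wraps the PettingZoo AEC env as an RLlib-compatible
# multi-agent env; the registered name "pistonball" is what the config
# below refers to.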

config = (
    PPOConfig()
    .environment("pistonball", env_config={"local_ratio": 0.5}, clip_rewards=True)
    .rollouts(
        num_rollout_workers=15 if not args.as_test else 2,
        num_envs_per_worker=1,
        observation_filter="NoFilter",
        rollout_fragment_length="auto",
    )
    # Use the framework chosen on the command line (the parsed arg was
    # previously unused).
    .framework(args.framework)
    .training(
        entropy_coeff=0.01,
        vf_loss_coeff=0.1,
        clip_param=0.1,
        vf_clip_param=10.0,
        num_sgd_iter=10,
        kl_coeff=0.5,
        lr=0.0001,
        grad_clip=100,
        sgd_minibatch_size=500,
        train_batch_size=5000,
        model={"vf_share_layers": True},
    )
    .resources(num_gpus=1 if not args.as_test else 0)
    .reporting(min_time_s_per_iteration=30)
)

tune.Tuner(
    "PPO",
    param_space=config.to_dict(),
    run_config=air.RunConfig(
        stop={
            "training_iteration": args.stop_iters,
            "timesteps_total": args.stop_timesteps,
            "episode_reward_mean": args.stop_reward,
        },
        verbose=2,
    ),
).fit()
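
# Example invocation, matching the BUILD test entry above:
#   python rllib/examples/env/greyscale_env.py --as-test --framework=torch --stop-iters=1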
