[RLlib] PyFlyt QuadX WayPoints #44758

Merged
merged 19 commits on Jun 6, 2024
125 changes: 125 additions & 0 deletions rllib/examples/quadx_waypoints.py
@@ -0,0 +1,125 @@
"""Example using the PyFlyt Gymnasium environment to train a UAV to reach waypoints.

PyFlyt GitHub Repository: https://github.com/jjshoots/PyFlyt/tree/master/PyFlyt
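This example requires the PyFlyt package, e.g. `pip install PyFlyt`
(see the repository above for installation details).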

How to run this script
----------------------
`python [script file name].py --enable-new-api-stack`


For debugging, use the following additional command line options
`--no-tune --num-env-runners=0`
which should allow you to set breakpoints anywhere in the RLlib code and
have the execution stop there for inspection and debugging.

For logging to your WandB account, use:
`--wandb-key=[your WandB API key] --wandb-project=[some project name]
--wandb-run-name=[optional: WandB run name (within the defined project)]`
"""

import os

import gymnasium as gym

from ray.tune.registry import get_trainable_cls
from ray.tune.result import TRAINING_ITERATION
from ray.rllib.utils.test_utils import (
    add_rllib_example_script_args,
    run_rllib_example_script_experiment,
)
from ray.rllib.utils.metrics import (
    ENV_RUNNER_RESULTS,
    EPISODE_RETURN_MEAN,
)

parser = add_rllib_example_script_args(
    default_iters=200,
    default_timesteps=100000,
    default_reward=90.0,
)
parser.add_argument(
    "--run", type=str, default="PPO", help="The RLlib-registered algorithm to use."
)
parser.add_argument("--env-name", type=str, default="quadx_waypoints")
parser.add_argument("--num-envs-per-worker", type=int, default=4)


class RewardWrapper(gym.RewardWrapper):
    """Scales PyFlyt's large terminal rewards (roughly +/-100) down to +/-10."""

    def reward(self, reward):
        # Only the large terminal bonuses/penalties are scaled; the smaller
        # shaping rewards in between pass through unchanged. For example, a
        # terminal reward of 100.0 becomes 10.0.
        if reward >= 99.0 or reward <= -99.0:
            return reward / 10
        return reward


def create_quadx_waypoints_env(env_config):
    # Importing PyFlyt.gym_envs registers the PyFlyt envs with Gymnasium.
    import PyFlyt.gym_envs  # noqa
    from PyFlyt.gym_envs import FlattenWaypointEnv

    env = gym.make("PyFlyt/QuadX-Waypoints-v1")
    # Scale terminal rewards down to the +/-10 range (see RewardWrapper above).
    env = RewardWrapper(env)
    # Flatten the dict waypoint observations into a single flat observation.
    return FlattenWaypointEnv(env, context_length=1)
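
# A quick standalone sanity check of the env creator (a sketch; assumes PyFlyt
# is installed and is not part of the RLlib training flow below):
#
#     env = create_quadx_waypoints_env({})
#     obs, info = env.reset()
#     obs, reward, terminated, truncated, info = env.step(env.action_space.sample())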


if __name__ == "__main__":
    from ray.tune.registry import register_env

    args = parser.parse_args()
    num_gpus = int(os.environ.get("RLLIB_NUM_GPUS", "0"))

    register_env(args.env_name, env_creator=create_quadx_waypoints_env)
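    # `register_env` makes the custom creator resolvable by the string name
    # passed to `config.environment(env=...)` below.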

    algo_cls = get_trainable_cls(args.run)
    config = algo_cls.get_default_config()

    # Size the Learner setup from the RLLIB_NUM_GPUS env var.
    config.environment(env=args.env_name).resources(
        num_learner_workers=num_gpus,
        num_gpus_per_learner_worker=num_gpus,
    ).rollouts(
        num_rollout_workers=args.num_env_runners,
        num_envs_per_worker=args.num_envs_per_worker,
    ).framework(
        args.framework
    ).api_stack(
        enable_rl_module_and_learner=True,
        enable_env_runner_and_connector_v2=True,
    ).reporting(
        min_time_s_per_iteration=0.1
    )
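    # The two `api_stack` flags above opt this config into RLlib's new API
    # stack (RLModule/Learner plus EnvRunner/ConnectorV2), matching the
    # `--enable-new-api-stack` flag mentioned in the docstring.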

    if args.run == "PPO":
        # Use a small model with shared policy/value layers for the
        # flattened waypoint observations.
        config.rl_module(
            model_config_dict={
                "fcnet_hiddens": [32],
                "fcnet_activation": "linear",
                "vf_share_layers": True,
            }
        )
        config.training(
            sgd_minibatch_size=128,
            train_batch_size=10000,
        )
    elif args.run == "IMPALA":
        config.rollouts(num_rollout_workers=2)
        config.resources(num_gpus=0)
        config.training(vf_loss_coeff=0.01)

    EPISODE_RETURN_MEAN_KEY = f"{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}"

    stop = {
        TRAINING_ITERATION: args.stop_iters,
        EPISODE_RETURN_MEAN_KEY: args.stop_reward,
    }
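    # Training stops when either criterion is met; the success metric below is
    # what `run_rllib_example_script_experiment` checks when `--as-test` is set.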

    run_rllib_example_script_experiment(
        config,
        args,
        stop=stop,
        success_metric={
            EPISODE_RETURN_MEAN_KEY: args.stop_reward,
        },
    )
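
As a quick smoke test of the merged example (a sketch; the flags are the ones
defined by the argparser above):

python rllib/examples/quadx_waypoints.py --enable-new-api-stack --num-env-runners=2 --stop-iters=5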