Skip to content

Commit

Permalink
[RLlib] Enhancements for multi-node/multi-GPU training and better Env…
Browse files Browse the repository at this point in the history
…Runner error msg. (ray-project#47705)

Signed-off-by: ujjawal-khare <[email protected]>
  • Loading branch information
sven1977 authored and ujjawal-khare committed Oct 15, 2024
1 parent ad10830 commit bec95ad
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 9 deletions.
10 changes: 8 additions & 2 deletions rllib/env/multi_agent_env_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -828,16 +828,22 @@ def make_env(self):
remote=self.config.remote_worker_envs,
)

# No env provided -> Error.
if not self.config.env:
raise ValueError(
"`config.env` is not provided! You should provide a valid environment "
"to your config through `config.environment([env descriptor e.g. "
"'CartPole-v1'])`."
)
# Register env for the local context.
# Note, `gym.register` has to be called on each worker.
if isinstance(self.config.env, str) and _global_registry.contains(
elif isinstance(self.config.env, str) and _global_registry.contains(
ENV_CREATOR, self.config.env
):
entry_point = partial(
_global_registry.get(ENV_CREATOR, self.config.env),
env_ctx,
)

else:
entry_point = partial(
_gym_env_creator,
Expand Down
10 changes: 8 additions & 2 deletions rllib/env/single_agent_env_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -794,16 +794,22 @@ def make_env(self) -> None:
remote=self.config.remote_worker_envs,
)

# No env provided -> Error.
if not self.config.env:
raise ValueError(
"`config.env` is not provided! You should provide a valid environment "
"to your config through `config.environment([env descriptor e.g. "
"'CartPole-v1'])`."
)
# Register env for the local context.
# Note, `gym.register` has to be called on each worker.
if isinstance(self.config.env, str) and _global_registry.contains(
elif isinstance(self.config.env, str) and _global_registry.contains(
ENV_CREATOR, self.config.env
):
entry_point = partial(
_global_registry.get(ENV_CREATOR, self.config.env),
env_ctx,
)

else:
entry_point = partial(
_gym_env_creator,
Expand Down
14 changes: 9 additions & 5 deletions rllib/utils/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1392,15 +1392,19 @@ def run_rllib_example_script_experiment(
# Define compute resources used automatically (only using the --num-gpus arg).
# New stack.
if config.enable_rl_module_and_learner:
# Do we have GPUs available in the cluster?
num_gpus = ray.cluster_resources().get("GPU", 0)
if args.num_gpus > 0 and num_gpus < args.num_gpus:
logger.warning(
f"You are running your script with --num-gpus={args.num_gpus}, "
f"but your cluster only has {num_gpus} GPUs! Will run "
f"with {num_gpus} CPU Learners instead."
)
# Define compute resources used.
config.resources(num_gpus=0)
config.learners(
num_learners=args.num_gpus,
num_gpus_per_learner=(
1
if torch and torch.cuda.is_available() and args.num_gpus > 0
else 0
),
num_gpus_per_learner=1 if num_gpus >= args.num_gpus > 0 else 0,
)
config.resources(num_gpus=0)
# Old stack.
Expand Down

0 comments on commit bec95ad

Please sign in to comment.