[RLlib] New API stack: (Multi)RLModule overhaul vol 01 (some preparatory cleanups). #47884

Merged
150 changes: 35 additions & 115 deletions rllib/algorithms/algorithm.py
@@ -685,7 +685,6 @@ def setup(self, config: AlgorithmConfig) -> None:
)
and self.config.input_ != "sampler"
and self.config.enable_rl_module_and_learner
and self.config.enable_env_runner_and_connector_v2
Contributor Author:
note: we don't need these double checks anymore b/c the hybrid stack has been deprecated already (users will get an error message in the config.validate() call).

Collaborator:
Yes, let's remove it!
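For context, a minimal sketch (not part of this PR) of the behavior the author refers to: configuring the now-deprecated hybrid stack is assumed to be rejected during config.validate(); the exact error type and message may vary by Ray version.

```python
# Minimal sketch, NOT from this PR: the hybrid API-stack combination (new
# Learner/RLModule API without the new EnvRunner/ConnectorV2 API) is assumed to be
# rejected at validation time, which is what makes the double check above redundant.
from ray.rllib.algorithms.ppo import PPOConfig

config = PPOConfig().api_stack(
    enable_rl_module_and_learner=True,
    # Hybrid stack: new Learner API, but old RolloutWorker/connector API.
    enable_env_runner_and_connector_v2=False,
)
try:
    config.validate()
except Exception as err:  # exact exception type/message may vary by Ray version
    print(f"Hybrid stack rejected by validate(): {err}")
```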

)
else self.config.num_env_runners
),
@@ -707,7 +706,6 @@ def setup(self, config: AlgorithmConfig) -> None:
)
and self.config.input_ != "sampler"
and self.config.enable_rl_module_and_learner
and self.config.enable_env_runner_and_connector_v2
):
from ray.rllib.offline.offline_data import OfflineData

@@ -797,20 +795,10 @@ def setup(self, config: AlgorithmConfig) -> None:
method_config["type"] = method_type

if self.config.enable_rl_module_and_learner:
if self.config.enable_env_runner_and_connector_v2:
module_spec: MultiRLModuleSpec = self.config.get_multi_rl_module_spec(
spaces=self.env_runner_group.get_spaces(),
inference_only=False,
)
# TODO (Sven): Deprecate this path: Old stack API RolloutWorkers and
# DreamerV3's EnvRunners have a `multi_rl_module_spec` property.
elif hasattr(self.env_runner, "multi_rl_module_spec"):
module_spec: MultiRLModuleSpec = self.env_runner.multi_rl_module_spec
else:
raise AttributeError(
"Your local EnvRunner/RolloutWorker does NOT have any property "
"referring to its RLModule!"
)
module_spec: MultiRLModuleSpec = self.config.get_multi_rl_module_spec(
spaces=self.env_runner_group.get_spaces(),
inference_only=False,
)
self.learner_group = self.config.build_learner_group(
rl_module_spec=module_spec
)
Expand All @@ -829,36 +817,19 @@ def setup(self, config: AlgorithmConfig) -> None:
rl_module_ckpt_dirs=rl_module_ckpt_dirs,
)

# Only when using RolloutWorkers: Update also the worker set's
# `is_policy_to_train`.
# Note that with the new EnvRunner API in combination with the new stack,
# this information only needs to be kept in the Learner and not on the
# EnvRunners anymore.
if not self.config.enable_env_runner_and_connector_v2:
policies_to_train = self.config.policies_to_train or set(
Collaborator:
Don't we need this still in the old stack?

Contributor Author:
The code will never get there anymore because users on the new API stack will call add_module and remove_module (instead of add_policy and remove_policy). But you are right, we should error out here directly if the config says otherwise. I'll add these exceptions ...

Contributor Author:
Ah, these errors are already there.
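To make the new-stack flow from the comment above concrete, here is a hedged sketch (not from this PR). It assumes Algorithm.add_module()/remove_module() with roughly these signatures and that unspecified RLModuleSpec fields fall back to the algorithm's defaults; exact requirements may differ across Ray versions.

```python
# Hedged sketch, NOT from this PR: on the new API stack, RLModules are added and
# removed directly (no Policy objects involved). Signatures may vary by Ray version.
import gymnasium as gym

from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.core.rl_module.rl_module import RLModuleSpec

algo = (
    PPOConfig()
    .api_stack(
        enable_rl_module_and_learner=True,
        enable_env_runner_and_connector_v2=True,
    )
    .environment("CartPole-v1")
    .build()
)

# Analogue of the old add_policy(): add another RLModule to the Learner(s) and
# EnvRunners. Unspecified spec fields (e.g. module_class) are assumed to fall back
# to the algorithm's defaults.
algo.add_module(
    module_id="extra_module",
    module_spec=RLModuleSpec(
        observation_space=gym.spaces.Box(-1.0, 1.0, (4,)),
        action_space=gym.spaces.Discrete(2),
    ),
)

# Analogue of the old remove_policy().
algo.remove_module(module_id="extra_module")
algo.stop()
```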

self.config.policies
)
self.env_runner_group.foreach_worker(
lambda w: w.set_is_policy_to_train(policies_to_train),
)
# Sync the weights from the learner group to the rollout workers.
self.env_runner.set_weights(self.learner_group.get_weights())
self.env_runner_group.sync_weights(inference_only=True)
# New stack/EnvRunner APIs: Use get/set_state.
else:
# Sync the weights from the learner group to the EnvRunners.
rl_module_state = self.learner_group.get_state(
components=COMPONENT_LEARNER + "/" + COMPONENT_RL_MODULE,
inference_only=True,
)[COMPONENT_LEARNER][COMPONENT_RL_MODULE]
self.env_runner.set_state({COMPONENT_RL_MODULE: rl_module_state})
self.env_runner_group.sync_env_runner_states(
config=self.config,
env_steps_sampled=self.metrics.peek(
NUM_ENV_STEPS_SAMPLED_LIFETIME, default=0
),
rl_module_state=rl_module_state,
)
# Sync the weights from the learner group to the EnvRunners.
rl_module_state = self.learner_group.get_state(
components=COMPONENT_LEARNER + "/" + COMPONENT_RL_MODULE,
inference_only=True,
)[COMPONENT_LEARNER][COMPONENT_RL_MODULE]
self.env_runner.set_state({COMPONENT_RL_MODULE: rl_module_state})
self.env_runner_group.sync_env_runner_states(
config=self.config,
env_steps_sampled=self.metrics.peek(
NUM_ENV_STEPS_SAMPLED_LIFETIME, default=0
),
rl_module_state=rl_module_state,
)

if self.offline_data:
# If the learners are remote we need to provide specific
@@ -1716,53 +1687,45 @@ def training_step(self) -> ResultDict:
"code and delete this error message)."
)

# Collect SampleBatches from sample workers until we have a full batch.
# Collect a list of Episodes from EnvRunners until we reach the train batch
# size.
with self.metrics.log_time((TIMERS, ENV_RUNNER_SAMPLING_TIMER)):
if self.config.count_steps_by == "agent_steps":
train_batch, env_runner_results = synchronous_parallel_sample(
episodes, env_runner_results = synchronous_parallel_sample(
Contributor Author:
this was a bug
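For readers following along: under the new EnvRunner API, synchronous_parallel_sample() returns a list of Episode objects rather than a SampleBatch, which is why the variable is renamed to episodes here and the as_multi_agent() call further down is dropped. A hedged sketch of that data structure (not from this PR; exact method signatures may vary by Ray version):

```python
# Hedged sketch, NOT from this PR: EnvRunners on the new API stack produce
# SingleAgentEpisode objects, which are not SampleBatches and have no
# `as_multi_agent()` method -- hence the switch to `update_from_episodes()` below.
from ray.rllib.env.single_agent_episode import SingleAgentEpisode

episode = SingleAgentEpisode()
episode.add_env_reset(observation=[0.0, 0.0, 0.0, 0.0])
episode.add_env_step(
    observation=[0.1, 0.0, 0.0, 0.0],
    action=1,
    reward=1.0,
    terminated=False,
)
print(episode.env_steps())                 # -> 1 env step logged so far
print(hasattr(episode, "as_multi_agent"))  # -> False
```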

worker_set=self.env_runner_group,
max_agent_steps=self.config.total_train_batch_size,
sample_timeout_s=self.config.sample_timeout_s,
_uses_new_env_runners=(
self.config.enable_env_runner_and_connector_v2
),
_uses_new_env_runners=True,
_return_metrics=True,
)
else:
train_batch, env_runner_results = synchronous_parallel_sample(
episodes, env_runner_results = synchronous_parallel_sample(
worker_set=self.env_runner_group,
max_env_steps=self.config.total_train_batch_size,
sample_timeout_s=self.config.sample_timeout_s,
_uses_new_env_runners=(
self.config.enable_env_runner_and_connector_v2
),
_uses_new_env_runners=True,
_return_metrics=True,
)
train_batch = train_batch.as_multi_agent()

# Reduce EnvRunner metrics over the n EnvRunners.
self.metrics.merge_and_log_n_dicts(env_runner_results, key=ENV_RUNNER_RESULTS)

# Only train if train_batch is not empty.
# In an extreme situation, all rollout workers die during the
# synchronous_parallel_sample() call above.
# In which case, we should skip training, wait a little bit, then probe again.
with self.metrics.log_time((TIMERS, LEARNER_UPDATE_TIMER)):
if train_batch.agent_steps() > 0:
learner_results = self.learner_group.update_from_batch(
batch=train_batch
)
self.metrics.log_dict(learner_results, key=LEARNER_RESULTS)
else:
# Wait 1 sec before probing again via weight syncing.
time.sleep(1.0)
learner_results = self.learner_group.update_from_episodes(
episodes=episodes,
timesteps={
NUM_ENV_STEPS_SAMPLED_LIFETIME: (
self.metrics.peek(NUM_ENV_STEPS_SAMPLED_LIFETIME)
),
},
)
self.metrics.log_dict(learner_results, key=LEARNER_RESULTS)

# Update weights - after learning on the local worker - on all
# remote workers (only those RLModules that were actually trained).
with self.metrics.log_time((TIMERS, SYNCH_WORKER_WEIGHTS_TIMER)):
self.env_runner_group.sync_weights(
from_worker_or_learner_group=self.learner_group,
policies=set(learner_results.keys()) - {ALL_MODULES},
policies=list(set(learner_results.keys()) - {ALL_MODULES}),
inference_only=True,
)

@@ -2492,29 +2455,6 @@ def add_policy(
module_spec=module_spec,
)

# If Learner API is enabled, we need to also add the underlying module
# to the learner group.
if add_to_learners and self.config.enable_rl_module_and_learner:
policy = self.get_policy(policy_id)
module = policy.model
self.learner_group.add_module(
module_id=policy_id,
module_spec=RLModuleSpec.from_module(module),
)

# Update each Learner's `policies_to_train` information, but only
# if the arg is explicitly provided here.
if policies_to_train is not None:
self.learner_group.foreach_learner(
func=lambda learner: learner.config.multi_agent(
policies_to_train=policies_to_train
),
timeout_seconds=0.0, # fire-and-forget
)

weights = policy.get_weights()
self.learner_group.set_weights({policy_id: weights})

# Add to evaluation workers, if necessary.
if add_to_eval_env_runners is True and self.eval_env_runner_group is not None:
self.eval_env_runner_group.add_policy(
@@ -2598,20 +2538,6 @@ def fn(worker):
if remove_from_env_runners:
self.env_runner_group.foreach_worker(fn, local_env_runner=True)

# Update each Learner's `policies_to_train` information, but only
# if the arg is explicitly provided here.
if (
remove_from_learners
and self.config.enable_rl_module_and_learner
and policies_to_train is not None
):
self.learner_group.foreach_learner(
func=lambda learner: learner.config.multi_agent(
policies_to_train=policies_to_train
),
timeout_seconds=0.0, # fire-and-forget
)

# Update the evaluation worker set's workers, if required.
if remove_from_eval_env_runners and self.eval_env_runner_group is not None:
self.eval_env_runner_group.foreach_worker(fn, local_env_runner=True)
@@ -2705,10 +2631,7 @@ def save_checkpoint(self, checkpoint_dir: str) -> None:
"""
# New API stack: Delegate to the `Checkpointable` implementation of
# `save_to_path()`.
if (
self.config.enable_rl_module_and_learner
and self.config.enable_env_runner_and_connector_v2
):
if self.config.enable_rl_module_and_learner:
return self.save_to_path(checkpoint_dir)

checkpoint_dir = pathlib.Path(checkpoint_dir)
@@ -2770,10 +2693,7 @@ def save_checkpoint(self, checkpoint_dir: str) -> None:
def load_checkpoint(self, checkpoint_dir: str) -> None:
# New API stack: Delegate to the `Checkpointable` implementation of
# `restore_from_path()`.
if (
self.config.enable_rl_module_and_learner
and self.config.enable_env_runner_and_connector_v2
):
if self.config.enable_rl_module_and_learner:
self.restore_from_path(checkpoint_dir)

# Call the `on_checkpoint_loaded` callback.
2 changes: 1 addition & 1 deletion rllib/connectors/connector_v2.py
@@ -320,7 +320,7 @@ def single_agent_episode_iterator(
list_indices = defaultdict(int)

# Single-agent case.
if isinstance(episodes[0], SingleAgentEpisode):
if episodes and isinstance(episodes[0], SingleAgentEpisode):
if zip_with_batch_column is not None:
if len(zip_with_batch_column) != len(episodes):
raise ValueError(
10 changes: 0 additions & 10 deletions rllib/core/rl_module/marl_module.py
@@ -1,15 +1,5 @@
from ray.rllib.utils.deprecation import deprecation_warning

from ray.rllib.core.rl_module.multi_rl_module import (
MultiRLModule,
MultiRLModuleSpec,
MultiRLModuleConfig,
)


MultiAgentRLModule = MultiRLModule
MultiAgentRLModuleConfig = MultiRLModuleConfig
MultiAgentRLModuleSpec = MultiRLModuleSpec

deprecation_warning(
old="ray.rllib.core.rl_module.marl_module",
2 changes: 1 addition & 1 deletion rllib/tuned_examples/bc/cartpole_bc.py
@@ -77,7 +77,7 @@
)

stop = {
f"{EVALUATION_RESULTS}/{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}": 120.0,
f"{EVALUATION_RESULTS}/{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}": 350.0,
TRAINING_ITERATION: 350,
}

3 changes: 1 addition & 2 deletions rllib/utils/metrics/__init__.py
@@ -6,9 +6,8 @@
LEARNER_RESULTS = "learners"
FAULT_TOLERANCE_STATS = "fault_tolerance"
TIMERS = "timers"
# ALGORITHM_RESULTS = "algorithm"

# RLModule metrics
# RLModule metrics.
NUM_TRAINABLE_PARAMETERS = "num_trainable_parameters"
NUM_NON_TRAINABLE_PARAMETERS = "num_non_trainable_parameters"

2 changes: 2 additions & 0 deletions rllib/utils/serialization.py
@@ -87,6 +87,8 @@ def gym_space_to_dict(space: gym.spaces.Space) -> Dict:
Returns:
Serialized JSON string.
"""
if space is None:
return None

def _box(sp: gym.spaces.Box) -> Dict:
return {
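A short usage note on the two added lines (grounded in this diff): a None space is now passed through unchanged instead of falling into the space-type dispatch below.

```python
# Grounded in the two added lines above: serializing a non-existent (None) space
# now simply yields None.
from ray.rllib.utils.serialization import gym_space_to_dict

assert gym_space_to_dict(None) is None
```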
15 changes: 8 additions & 7 deletions rllib/utils/test_utils.py
@@ -1,19 +1,13 @@
import argparse
from collections import Counter
import copy
import gymnasium as gym
from gymnasium.spaces import Box, Discrete, MultiDiscrete, MultiBinary
from gymnasium.spaces import Dict as GymDict
from gymnasium.spaces import Tuple as GymTuple
import json
import logging
import numpy as np
import os
import pprint
import random
import re
import time
import tree # pip install dm_tree
from typing import (
TYPE_CHECKING,
Any,
@@ -26,11 +20,19 @@
)
import yaml

import gymnasium as gym
from gymnasium.spaces import Box, Discrete, MultiDiscrete, MultiBinary
from gymnasium.spaces import Dict as GymDict
from gymnasium.spaces import Tuple as GymTuple
import numpy as np
import tree # pip install dm_tree

import ray
from ray import air, tune
from ray.air.constants import TRAINING_ITERATION
from ray.air.integrations.wandb import WandbLoggerCallback, WANDB_ENV_VAR
from ray.rllib.common import SupportedFileType
from ray.rllib.core import DEFAULT_MODULE_ID
from ray.rllib.env.wrappers.atari_wrappers import is_atari, wrap_deepmind
from ray.rllib.train import load_experiments_from_file
from ray.rllib.utils.annotations import OldAPIStack
@@ -789,7 +791,6 @@ def check_train_results_new_api_stack(train_results: ResultDict) -> None:
data in it.
"""
# Import these here to avoid circular dependencies.
from ray.rllib.core import DEFAULT_MODULE_ID
from ray.rllib.utils.metrics import (
ENV_RUNNER_RESULTS,
FAULT_TOLERANCE_STATS,