[RLlib] New API stack: (Multi)RLModule overhaul vol 01 (some preparatory cleanups). #47884
Changes from all commits: 5ee4234, 73c5d65, 0d52c1d, ed960ea, f3616e3
@@ -685,7 +685,6 @@ def setup(self, config: AlgorithmConfig) -> None:
 )
 and self.config.input_ != "sampler"
 and self.config.enable_rl_module_and_learner
-and self.config.enable_env_runner_and_connector_v2
 )
 else self.config.num_env_runners
 ),
@@ -707,7 +706,6 @@ def setup(self, config: AlgorithmConfig) -> None:
 )
 and self.config.input_ != "sampler"
 and self.config.enable_rl_module_and_learner
-and self.config.enable_env_runner_and_connector_v2
 ):
 from ray.rllib.offline.offline_data import OfflineData
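For context (not part of the diff): on the new API stack both flags are switched on together via `AlgorithmConfig.api_stack()`, which is why gating on `enable_rl_module_and_learner` alone is now enough for the offline-data path above. A minimal sketch, assuming PPO and a hypothetical data path:

```python
from ray.rllib.algorithms.ppo import PPOConfig

config = (
    PPOConfig()
    # On the new API stack, both flags are enabled together (the "hybrid"
    # combination is no longer supported).
    .api_stack(
        enable_rl_module_and_learner=True,
        enable_env_runner_and_connector_v2=True,
    )
    # Any input other than "sampler" makes `Algorithm.setup()` take the
    # OfflineData branch shown in the hunk above.
    .offline_data(input_=["/tmp/my_offline_data"])  # hypothetical path
)
```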
@@ -797,20 +795,10 @@ def setup(self, config: AlgorithmConfig) -> None:
 method_config["type"] = method_type

 if self.config.enable_rl_module_and_learner:
-if self.config.enable_env_runner_and_connector_v2:
-module_spec: MultiRLModuleSpec = self.config.get_multi_rl_module_spec(
-spaces=self.env_runner_group.get_spaces(),
-inference_only=False,
-)
-# TODO (Sven): Deprecate this path: Old stack API RolloutWorkers and
-# DreamerV3's EnvRunners have a `multi_rl_module_spec` property.
-elif hasattr(self.env_runner, "multi_rl_module_spec"):
-module_spec: MultiRLModuleSpec = self.env_runner.multi_rl_module_spec
-else:
-raise AttributeError(
-"Your local EnvRunner/RolloutWorker does NOT have any property "
-"referring to its RLModule!"
-)
+module_spec: MultiRLModuleSpec = self.config.get_multi_rl_module_spec(
+spaces=self.env_runner_group.get_spaces(),
+inference_only=False,
+)
 self.learner_group = self.config.build_learner_group(
 rl_module_spec=module_spec
 )
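As a user-level sketch of the single remaining code path (illustrative only; it assumes `get_multi_rl_module_spec()` also accepts an `env` argument for deriving spaces, and uses CartPole as a stand-in environment — in `setup()` above the spaces come from the EnvRunnerGroup instead):

```python
import gymnasium as gym
from ray.rllib.algorithms.ppo import PPOConfig

config = (
    PPOConfig()
    .api_stack(
        enable_rl_module_and_learner=True,
        enable_env_runner_and_connector_v2=True,
    )
    .environment("CartPole-v1")
)

# On the new stack, the (Multi)RLModuleSpec always comes from the config.
module_spec = config.get_multi_rl_module_spec(env=gym.make("CartPole-v1"))

# The LearnerGroup is then built from that spec, exactly as in the new code.
learner_group = config.build_learner_group(rl_module_spec=module_spec)
```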
@@ -829,36 +817,19 @@ def setup(self, config: AlgorithmConfig) -> None:
 rl_module_ckpt_dirs=rl_module_ckpt_dirs,
 )

-# Only when using RolloutWorkers: Update also the worker set's
-# `is_policy_to_train`.
-# Note that with the new EnvRunner API in combination with the new stack,
-# this information only needs to be kept in the Learner and not on the
-# EnvRunners anymore.
-if not self.config.enable_env_runner_and_connector_v2:
-policies_to_train = self.config.policies_to_train or set(

  > Review comment on this line: "Don't we need this still in the old stack?"
  > Reply: "The code will never get there anymore because users on the new API stack will call […]"
  > Reply: "Ah, these errors are already there."

-self.config.policies
-)
-self.env_runner_group.foreach_worker(
-lambda w: w.set_is_policy_to_train(policies_to_train),
-)
-# Sync the weights from the learner group to the rollout workers.
-self.env_runner.set_weights(self.learner_group.get_weights())
-self.env_runner_group.sync_weights(inference_only=True)
-# New stack/EnvRunner APIs: Use get/set_state.
-else:
-# Sync the weights from the learner group to the EnvRunners.
-rl_module_state = self.learner_group.get_state(
-components=COMPONENT_LEARNER + "/" + COMPONENT_RL_MODULE,
-inference_only=True,
-)[COMPONENT_LEARNER][COMPONENT_RL_MODULE]
-self.env_runner.set_state({COMPONENT_RL_MODULE: rl_module_state})
-self.env_runner_group.sync_env_runner_states(
-config=self.config,
-env_steps_sampled=self.metrics.peek(
-NUM_ENV_STEPS_SAMPLED_LIFETIME, default=0
-),
-rl_module_state=rl_module_state,
-)
+# Sync the weights from the learner group to the EnvRunners.
+rl_module_state = self.learner_group.get_state(
+components=COMPONENT_LEARNER + "/" + COMPONENT_RL_MODULE,
+inference_only=True,
+)[COMPONENT_LEARNER][COMPONENT_RL_MODULE]
+self.env_runner.set_state({COMPONENT_RL_MODULE: rl_module_state})
+self.env_runner_group.sync_env_runner_states(
+config=self.config,
+env_steps_sampled=self.metrics.peek(
+NUM_ENV_STEPS_SAMPLED_LIFETIME, default=0
+),
+rl_module_state=rl_module_state,
+)

 if self.offline_data:
 # If the learners are remote we need to provide specific
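To make the retained branch easier to follow, here is the same sync written out against a built Algorithm (a sketch only; `algo` is assumed to be an already-built new-API-stack Algorithm instance, not a name from this PR):

```python
from ray.rllib.core import COMPONENT_LEARNER, COMPONENT_RL_MODULE
from ray.rllib.utils.metrics import NUM_ENV_STEPS_SAMPLED_LIFETIME

# Ask the LearnerGroup for just its RLModule sub-component (no optimizer
# state), then unwrap the nested state dict.
rl_module_state = algo.learner_group.get_state(
    components=COMPONENT_LEARNER + "/" + COMPONENT_RL_MODULE,
    inference_only=True,
)[COMPONENT_LEARNER][COMPONENT_RL_MODULE]

# Push the weights into the local EnvRunner ...
algo.env_runner.set_state({COMPONENT_RL_MODULE: rl_module_state})

# ... and broadcast them (plus connector states) to all remote EnvRunners.
algo.env_runner_group.sync_env_runner_states(
    config=algo.config,
    env_steps_sampled=algo.metrics.peek(NUM_ENV_STEPS_SAMPLED_LIFETIME, default=0),
    rl_module_state=rl_module_state,
)
```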
@@ -1716,53 +1687,45 @@ def training_step(self) -> ResultDict:
 "code and delete this error message)."
 )

-# Collect SampleBatches from sample workers until we have a full batch.
+# Collect a list of Episodes from EnvRunners until we reach the train batch
+# size.
 with self.metrics.log_time((TIMERS, ENV_RUNNER_SAMPLING_TIMER)):
 if self.config.count_steps_by == "agent_steps":
-train_batch, env_runner_results = synchronous_parallel_sample(
+episodes, env_runner_results = synchronous_parallel_sample(

  > Review comment on this line: "this was a bug"

 worker_set=self.env_runner_group,
 max_agent_steps=self.config.total_train_batch_size,
 sample_timeout_s=self.config.sample_timeout_s,
-_uses_new_env_runners=(
-self.config.enable_env_runner_and_connector_v2
-),
+_uses_new_env_runners=True,
 _return_metrics=True,
 )
 else:
-train_batch, env_runner_results = synchronous_parallel_sample(
+episodes, env_runner_results = synchronous_parallel_sample(
 worker_set=self.env_runner_group,
 max_env_steps=self.config.total_train_batch_size,
 sample_timeout_s=self.config.sample_timeout_s,
-_uses_new_env_runners=(
-self.config.enable_env_runner_and_connector_v2
-),
+_uses_new_env_runners=True,
 _return_metrics=True,
 )
-train_batch = train_batch.as_multi_agent()

 # Reduce EnvRunner metrics over the n EnvRunners.
 self.metrics.merge_and_log_n_dicts(env_runner_results, key=ENV_RUNNER_RESULTS)

-# Only train if train_batch is not empty.
-# In an extreme situation, all rollout workers die during the
-# synchronous_parallel_sample() call above.
-# In which case, we should skip training, wait a little bit, then probe again.
 with self.metrics.log_time((TIMERS, LEARNER_UPDATE_TIMER)):
-if train_batch.agent_steps() > 0:
-learner_results = self.learner_group.update_from_batch(
-batch=train_batch
-)
-self.metrics.log_dict(learner_results, key=LEARNER_RESULTS)
-else:
-# Wait 1 sec before probing again via weight syncing.
-time.sleep(1.0)
+learner_results = self.learner_group.update_from_episodes(
+episodes=episodes,
+timesteps={
+NUM_ENV_STEPS_SAMPLED_LIFETIME: (
+self.metrics.peek(NUM_ENV_STEPS_SAMPLED_LIFETIME)
+),
+},
+)
+self.metrics.log_dict(learner_results, key=LEARNER_RESULTS)

 # Update weights - after learning on the local worker - on all
 # remote workers (only those RLModules that were actually trained).
 with self.metrics.log_time((TIMERS, SYNCH_WORKER_WEIGHTS_TIMER)):
 self.env_runner_group.sync_weights(
 from_worker_or_learner_group=self.learner_group,
-policies=set(learner_results.keys()) - {ALL_MODULES},
+policies=list(set(learner_results.keys()) - {ALL_MODULES}),
 inference_only=True,
 )
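Condensed, the new-stack `training_step()` above boils down to three steps. The sketch below is illustrative only; it assumes env-step counting and that `algo` is a built new-API-stack Algorithm:

```python
from ray.rllib.core import ALL_MODULES
from ray.rllib.execution.rollout_ops import synchronous_parallel_sample
from ray.rllib.utils.metrics import (
    ENV_RUNNER_RESULTS,
    NUM_ENV_STEPS_SAMPLED_LIFETIME,
)

# 1) Sample lists of Episodes (not SampleBatches) from all EnvRunners.
episodes, env_runner_results = synchronous_parallel_sample(
    worker_set=algo.env_runner_group,
    max_env_steps=algo.config.total_train_batch_size,
    sample_timeout_s=algo.config.sample_timeout_s,
    _uses_new_env_runners=True,
    _return_metrics=True,
)
algo.metrics.merge_and_log_n_dicts(env_runner_results, key=ENV_RUNNER_RESULTS)

# 2) Update the Learners straight from those episodes (no `as_multi_agent()`
#    conversion and no `update_from_batch()` anymore).
learner_results = algo.learner_group.update_from_episodes(
    episodes=episodes,
    timesteps={
        NUM_ENV_STEPS_SAMPLED_LIFETIME: algo.metrics.peek(
            NUM_ENV_STEPS_SAMPLED_LIFETIME
        ),
    },
)

# 3) Sync only the RLModules that were actually trained back to the EnvRunners.
algo.env_runner_group.sync_weights(
    from_worker_or_learner_group=algo.learner_group,
    policies=list(set(learner_results.keys()) - {ALL_MODULES}),
    inference_only=True,
)
```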
@@ -2492,29 +2455,6 @@ def add_policy(
 module_spec=module_spec,
 )

-# If Learner API is enabled, we need to also add the underlying module
-# to the learner group.
-if add_to_learners and self.config.enable_rl_module_and_learner:
-policy = self.get_policy(policy_id)
-module = policy.model
-self.learner_group.add_module(
-module_id=policy_id,
-module_spec=RLModuleSpec.from_module(module),
-)
-
-# Update each Learner's `policies_to_train` information, but only
-# if the arg is explicitly provided here.
-if policies_to_train is not None:
-self.learner_group.foreach_learner(
-func=lambda learner: learner.config.multi_agent(
-policies_to_train=policies_to_train
-),
-timeout_seconds=0.0, # fire-and-forget
-)
-
-weights = policy.get_weights()
-self.learner_group.set_weights({policy_id: weights})

 # Add to evaluation workers, if necessary.
 if add_to_eval_env_runners is True and self.eval_env_runner_group is not None:
 self.eval_env_runner_group.add_policy(
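With this block gone, `add_policy()` no longer touches the LearnerGroup. On the new API stack a module would instead be added to the Learners directly, roughly as sketched below. This is an assumption built on the `LearnerGroup.add_module()` call visible in the removed code; `algo`, `MyRLModule`, `obs_space`, and `act_space` are placeholders, not names from this PR:

```python
from ray.rllib.core.rl_module import RLModuleSpec

# Register a new RLModule with all Learners (rather than going through the
# old-stack `Algorithm.add_policy()` path).
algo.learner_group.add_module(
    module_id="my_new_module",
    module_spec=RLModuleSpec(
        module_class=MyRLModule,       # your own RLModule subclass
        observation_space=obs_space,   # placeholder spaces
        action_space=act_space,
    ),
)
```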
@@ -2598,20 +2538,6 @@ def fn(worker):
 if remove_from_env_runners:
 self.env_runner_group.foreach_worker(fn, local_env_runner=True)

-# Update each Learner's `policies_to_train` information, but only
-# if the arg is explicitly provided here.
-if (
-remove_from_learners
-and self.config.enable_rl_module_and_learner
-and policies_to_train is not None
-):
-self.learner_group.foreach_learner(
-func=lambda learner: learner.config.multi_agent(
-policies_to_train=policies_to_train
-),
-timeout_seconds=0.0, # fire-and-forget
-)

 # Update the evaluation worker set's workers, if required.
 if remove_from_eval_env_runners and self.eval_env_runner_group is not None:
 self.eval_env_runner_group.foreach_worker(fn, local_env_runner=True)
@@ -2705,10 +2631,7 @@ def save_checkpoint(self, checkpoint_dir: str) -> None:
 """
 # New API stack: Delegate to the `Checkpointable` implementation of
 # `save_to_path()`.
-if (
-self.config.enable_rl_module_and_learner
-and self.config.enable_env_runner_and_connector_v2
-):
+if self.config.enable_rl_module_and_learner:
 return self.save_to_path(checkpoint_dir)

 checkpoint_dir = pathlib.Path(checkpoint_dir)
@@ -2770,10 +2693,7 @@ def save_checkpoint(self, checkpoint_dir: str) -> None:
 def load_checkpoint(self, checkpoint_dir: str) -> None:
 # New API stack: Delegate to the `Checkpointable` implementation of
 # `restore_from_path()`.
-if (
-self.config.enable_rl_module_and_learner
-and self.config.enable_env_runner_and_connector_v2
-):
+if self.config.enable_rl_module_and_learner:
 self.restore_from_path(checkpoint_dir)

 # Call the `on_checkpoint_loaded` callback.
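For reference, both simplified checks delegate to the generic `Checkpointable` interface named in the comments above. A minimal usage sketch (assumes `algo` is a built new-API-stack Algorithm and a hypothetical, writable directory):

```python
ckpt_dir = "/tmp/my_algo_checkpoint"  # hypothetical location

# `Algorithm` implements `Checkpointable`, so the algorithm state is written
# and read back via these two calls.
algo.save_to_path(ckpt_dir)
algo.restore_from_path(ckpt_dir)
```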
> Review comment: "note: we don't need these double checks anymore b/c the hybrid stack has been deprecated already (users will get error message in the config.validate() call)."
> Reply: "Yes, let's remove it!"
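A small sketch of what that review note implies (an assumption based on the comment above, not code from this PR): enabling only one of the two flags, the old hybrid combination, should already be rejected when the config is validated, so the Algorithm code no longer needs the double checks.

```python
from ray.rllib.algorithms.ppo import PPOConfig

config = PPOConfig().api_stack(
    enable_rl_module_and_learner=True,
    enable_env_runner_and_connector_v2=False,  # deprecated hybrid combination
)

# Per the review note, this is expected to surface an error for the
# deprecated hybrid stack.
config.validate()
```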