
Commit 390e591

[RLlib] New API stack: (Multi)RLModule overhaul vol 01 (some preparatory cleanups). (ray-project#47884)

Signed-off-by: ujjawal-khare <[email protected]>
sven1977 authored and ujjawal-khare committed Oct 15, 2024
1 parent ffc1746 commit 390e591
Showing 3 changed files with 15 additions and 36 deletions.
43 changes: 13 additions & 30 deletions rllib/algorithms/algorithm.py
@@ -817,36 +817,19 @@ def setup(self, config: AlgorithmConfig) -> None:
                 rl_module_ckpt_dirs=rl_module_ckpt_dirs,
             )
 
-        # Only when using RolloutWorkers: Update also the worker set's
-        # `is_policy_to_train`.
-        # Note that with the new EnvRunner API in combination with the new stack,
-        # this information only needs to be kept in the Learner and not on the
-        # EnvRunners anymore.
-        if not self.config.enable_env_runner_and_connector_v2:
-            policies_to_train = self.config.policies_to_train or set(
-                self.config.policies
-            )
-            self.env_runner_group.foreach_worker(
-                lambda w: w.set_is_policy_to_train(policies_to_train),
-            )
-            # Sync the weights from the learner group to the rollout workers.
-            self.env_runner.set_weights(self.learner_group.get_weights())
-            self.env_runner_group.sync_weights(inference_only=True)
-        # New stack/EnvRunner APIs: Use get/set_state.
-        else:
-            # Sync the weights from the learner group to the EnvRunners.
-            rl_module_state = self.learner_group.get_state(
-                components=COMPONENT_LEARNER + "/" + COMPONENT_RL_MODULE,
-                inference_only=True,
-            )[COMPONENT_LEARNER][COMPONENT_RL_MODULE]
-            self.env_runner.set_state({COMPONENT_RL_MODULE: rl_module_state})
-            self.env_runner_group.sync_env_runner_states(
-                config=self.config,
-                env_steps_sampled=self.metrics.peek(
-                    NUM_ENV_STEPS_SAMPLED_LIFETIME, default=0
-                ),
-                rl_module_state=rl_module_state,
-            )
+        # Sync the weights from the learner group to the EnvRunners.
+        rl_module_state = self.learner_group.get_state(
+            components=COMPONENT_LEARNER + "/" + COMPONENT_RL_MODULE,
+            inference_only=True,
+        )[COMPONENT_LEARNER][COMPONENT_RL_MODULE]
+        self.env_runner.set_state({COMPONENT_RL_MODULE: rl_module_state})
+        self.env_runner_group.sync_env_runner_states(
+            config=self.config,
+            env_steps_sampled=self.metrics.peek(
+                NUM_ENV_STEPS_SAMPLED_LIFETIME, default=0
+            ),
+            rl_module_state=rl_module_state,
+        )
 
         if self.offline_data:
             # If the learners are remote we need to provide specific
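For context, a minimal sketch of the weight-sync pattern used in the hunk above, shown outside of Algorithm.setup(): on the new API stack the LearnerGroup holds the canonical RLModule weights, and an EnvRunner receives them through the component-based get_state/set_state API. The PPO config, the api_stack() switches, and the direct algo.learner_group / algo.env_runner access are illustrative assumptions, not code from this commit.

```python
# Sketch only: assumes the new API stack with COMPONENT_LEARNER /
# COMPONENT_RL_MODULE exported from ray.rllib.core; not code from this commit.
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.core import COMPONENT_LEARNER, COMPONENT_RL_MODULE

config = (
    PPOConfig()
    .api_stack(
        enable_rl_module_and_learner=True,
        enable_env_runner_and_connector_v2=True,
    )
    .environment("CartPole-v1")
)
algo = config.build()

# Pull only the (inference-only) RLModule weights out of the LearnerGroup.
# The returned dict is nested by component name, hence the double indexing.
rl_module_state = algo.learner_group.get_state(
    components=COMPONENT_LEARNER + "/" + COMPONENT_RL_MODULE,
    inference_only=True,
)[COMPONENT_LEARNER][COMPONENT_RL_MODULE]

# Hand the weights to the local EnvRunner through the generic set_state API.
algo.env_runner.set_state({COMPONENT_RL_MODULE: rl_module_state})

algo.stop()
```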
2 changes: 1 addition & 1 deletion rllib/tuned_examples/bc/cartpole_bc.py
@@ -79,7 +79,7 @@
 )
 
 stop = {
-    f"{EVALUATION_RESULTS}/{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}": 120.0,
+    f"{EVALUATION_RESULTS}/{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}": 350.0,
     TRAINING_ITERATION: 350,
 }
 
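For reference, a minimal sketch of how a stop dict with this nested metric key is typically handed to Tune. The BCConfig setup and the Tuner/RunConfig wiring are illustrative assumptions rather than the actual contents of cartpole_bc.py, and a runnable BC experiment additionally needs offline input data configured.

```python
# Sketch only: illustrates the nested metric key used as a stop criterion.
# The Tuner/RunConfig wiring is an assumption, not taken from cartpole_bc.py.
from ray import train, tune
from ray.air.constants import TRAINING_ITERATION
from ray.rllib.algorithms.bc import BCConfig
from ray.rllib.utils.metrics import (
    ENV_RUNNER_RESULTS,
    EPISODE_RETURN_MEAN,
    EVALUATION_RESULTS,
)

stop = {
    # Stop once the evaluation EnvRunners report a mean episode return of 350 ...
    f"{EVALUATION_RESULTS}/{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}": 350.0,
    # ... or after at most 350 training iterations, whichever comes first.
    TRAINING_ITERATION: 350,
}

config = BCConfig().environment("CartPole-v1")
# NOTE: A runnable BC experiment also needs offline data, e.g.
# config.offline_data(input_="/path/to/cartpole/data").
tuner = tune.Tuner(
    "BC",
    param_space=config,
    run_config=train.RunConfig(stop=stop),
)
# results = tuner.fit()
```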
6 changes: 1 addition & 5 deletions rllib/utils/test_utils.py
@@ -1,10 +1,6 @@
 import argparse
 from collections import Counter
 import copy
-import gymnasium as gym
-from gymnasium.spaces import Box, Discrete, MultiDiscrete, MultiBinary
-from gymnasium.spaces import Dict as GymDict
-from gymnasium.spaces import Tuple as GymTuple
 import json
 import logging
 import os
@@ -36,7 +32,7 @@
 from ray.air.constants import TRAINING_ITERATION
 from ray.air.integrations.wandb import WandbLoggerCallback, WANDB_ENV_VAR
 from ray.rllib.common import SupportedFileType
-from ray.rllib.core import DEFAULT_MODULE_ID, Columns
+from ray.rllib.core import DEFAULT_MODULE_ID
 from ray.rllib.env.wrappers.atari_wrappers import is_atari, wrap_deepmind
 from ray.rllib.train import load_experiments_from_file
 from ray.rllib.utils.annotations import OldAPIStack
