From 9b55915c38ae0c51089cbf2047f99429f43e6354 Mon Sep 17 00:00:00 2001 From: simonsays1980 Date: Mon, 30 Sep 2024 15:29:25 +0200 Subject: [PATCH] [RLlib] Update autoregressive actions example. (#47829) Signed-off-by: ujjawal-khare --- rllib/BUILD | 50 +++++++++++++++++++ rllib/examples/autoregressive_action_dist.py | 4 +- .../classes/autoregressive_actions_rlm.py | 12 +++-- 3 files changed, 60 insertions(+), 6 deletions(-) diff --git a/rllib/BUILD b/rllib/BUILD index 22dff4a00c3f..393c105fa4c8 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -3263,6 +3263,56 @@ py_test( args = ["--as-test", "--framework=torch", "--stop-reward=-0.012", "--num-cpus=4"] ) +#@OldAPIStack +py_test( + name = "examples/cartpole_lstm_impala_tf2", + main = "examples/cartpole_lstm.py", + tags = ["team:rllib", "exclusive", "examples"], + size = "medium", + srcs = ["examples/cartpole_lstm.py"], + args = ["--run=IMPALA", "--as-test", "--framework=tf2", "--stop-reward=28", "--num-cpus=4"] +) + +#@OldAPIStack +py_test( + name = "examples/cartpole_lstm_impala_torch", + main = "examples/cartpole_lstm.py", + tags = ["team:rllib", "exclusive", "examples"], + size = "medium", + srcs = ["examples/cartpole_lstm.py"], + args = ["--run=IMPALA", "--as-test", "--framework=torch", "--stop-reward=28", "--num-cpus=4"] +) + +#@OldAPIStack +py_test( + name = "examples/cartpole_lstm_ppo_tf2", + main = "examples/cartpole_lstm.py", + tags = ["team:rllib", "exclusive", "examples"], + size = "large", + srcs = ["examples/cartpole_lstm.py"], + args = ["--run=PPO", "--as-test", "--framework=tf2", "--stop-reward=28", "--num-cpus=4"] +) + +#@OldAPIStack +py_test( + name = "examples/cartpole_lstm_ppo_torch", + main = "examples/cartpole_lstm.py", + tags = ["team:rllib", "exclusive", "examples"], + size = "medium", + srcs = ["examples/cartpole_lstm.py"], + args = ["--run=PPO", "--as-test", "--framework=torch", "--stop-reward=28", "--num-cpus=4"] +) + +#@OldAPIStack +py_test( + name = "examples/cartpole_lstm_ppo_torch_with_prev_a_and_r", + main = "examples/cartpole_lstm.py", + tags = ["team:rllib", "exclusive", "examples"], + size = "medium", + srcs = ["examples/cartpole_lstm.py"], + args = ["--run=PPO", "--as-test", "--framework=torch", "--stop-reward=28", "--num-cpus=4", "--use-prev-action", "--use-prev-reward"] +) + #@OldAPIStack py_test( name = "examples/centralized_critic_tf", diff --git a/rllib/examples/autoregressive_action_dist.py b/rllib/examples/autoregressive_action_dist.py index b9e7e330ccb8..5dfac509e580 100644 --- a/rllib/examples/autoregressive_action_dist.py +++ b/rllib/examples/autoregressive_action_dist.py @@ -147,7 +147,9 @@ def get_cli_args(): config = ( get_trainable_cls(args.run) .get_default_config() - .environment(CorrelatedActionsEnv) + # Batch-norm models have not been migrated to the RL Module API yet. + .api_stack(enable_rl_module_and_learner=False) + .environment(AutoRegressiveActionEnv) .framework(args.framework) .training(gamma=0.5) # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0. diff --git a/rllib/examples/rl_modules/classes/autoregressive_actions_rlm.py b/rllib/examples/rl_modules/classes/autoregressive_actions_rlm.py index 83aca0acc500..2b683758eb8b 100644 --- a/rllib/examples/rl_modules/classes/autoregressive_actions_rlm.py +++ b/rllib/examples/rl_modules/classes/autoregressive_actions_rlm.py @@ -1,10 +1,11 @@ import abc from abc import abstractmethod -from typing import Any, Dict +from typing import Dict from ray.rllib.core import Columns from ray.rllib.core.models.base import ENCODER_OUT from ray.rllib.core.models.configs import MLPHeadConfig +from ray.rllib.core.models.specs.specs_dict import SpecDict from ray.rllib.core.rl_module.apis.value_function_api import ValueFunctionAPI from ray.rllib.core.rl_module.rl_module import RLModule from ray.rllib.core.rl_module.torch.torch_rl_module import TorchRLModule @@ -233,10 +234,11 @@ def _forward_train(self, batch: Dict[str, TensorType]) -> Dict[str, TensorType]: return outs @override(ValueFunctionAPI) - def compute_values(self, batch: Dict[str, TensorType], embeddings=None): - # Encoder forward pass to get `embeddings`, if necessary. - if embeddings is None: - embeddings = self.encoder(batch)[ENCODER_OUT] + def compute_values(self, batch: Dict[str, TensorType]): + + # Encoder forward pass. + encoder_outs = self.encoder(batch)[ENCODER_OUT] + # Value head forward pass. vf_out = self.vf(embeddings) # Squeeze out last dimension (single node value head).