diff --git a/rllib/examples/envs/classes/correlated_actions_env.py b/rllib/examples/envs/classes/correlated_actions_env.py
index 055a4d75d558..a3db0d7556d9 100644
--- a/rllib/examples/envs/classes/correlated_actions_env.py
+++ b/rllib/examples/envs/classes/correlated_actions_env.py
@@ -1,5 +1,6 @@
 import gymnasium as gym
-from gymnasium.spaces import Discrete, Tuple
+from gymnasium.spaces import Box, Discrete, Tuple
+import numpy as np
 import random
 
 
@@ -13,19 +14,19 @@ class CorrelatedActionsEnv(gym.Env):
     to a1. I.e., +10 at most per step.
 
     One way to effectively learn this is through correlated action
-    distributions, e.g., in examples/autoregressive_action_dist.py
+    distributions, e.g., in examples/rl_modules/autoregressive_action_rlm.py
 
     There are 20 steps. Hence, the best score would be ~200 reward.
     """
 
-    def __init__(self, _):
-        self.observation_space = Discrete(2)
+    def __init__(self, _=None):
+        self.observation_space = Box(0, 1, shape=(1,), dtype=np.float32)
         self.action_space = Tuple([Discrete(2), Discrete(2)])
         self.last_observation = None
 
     def reset(self, *, seed=None, options=None):
         self.t = 0
-        self.last_observation = random.choice([0, 1])
+        self.last_observation = np.array([random.choice([0, 1])], dtype=np.float32)
         return self.last_observation, {}
 
@@ -39,5 +40,5 @@ def step(self, action):
         if a1 == a2:
             reward += 5
         done = truncated = self.t > 20
-        self.last_observation = random.choice([0, 1])
+        self.last_observation = np.array([random.choice([0, 1])], dtype=np.float32)
         return self.last_observation, reward, done, truncated, {}
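The environment change above swaps the `Discrete(2)` observation for a float32 `Box` of shape `(1,)`, which RLlib's default encoders can consume directly. Below is a minimal, hypothetical smoke test of the patched class; the rollout loop is illustrative only and assumes the reward shaping described in the env docstring (a1 should match the observation, a2 should match a1):

from ray.rllib.examples.envs.classes.correlated_actions_env import CorrelatedActionsEnv

env = CorrelatedActionsEnv()
obs, _ = env.reset()  # Now a float32 Box observation of shape (1,).
total_reward = 0.0
terminated = truncated = False
while not (terminated or truncated):
    a1 = int(obs[0])  # Match the observation ...
    a2 = a1           # ... and synchronize the second action with the first.
    obs, reward, terminated, truncated, _ = env.step((a1, a2))
    total_reward += reward
print(total_reward)   # Should approach the ~200 mentioned in the env docstring.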
diff --git a/rllib/examples/rl_modules/autoregressive_actions_rlm.py b/rllib/examples/rl_modules/autoregressive_actions_rlm.py
new file mode 100644
index 000000000000..7920ca622738
--- /dev/null
+++ b/rllib/examples/rl_modules/autoregressive_actions_rlm.py
@@ -0,0 +1,107 @@
+"""An example script showing how to define and load an `RLModule` with
+a dependent action space.
+
+This example:
+    - Defines an `RLModule` with autoregressive actions.
+    - It does so by implementing a prior distribution for the first action
+      and then conditioning a posterior distribution on that sampled action.
+    - Furthermore, it uses RLlib's simple base `Catalog` class inside the
+      `RLModule` to build the distributions.
+    - Uses this `RLModule` in a PPO training run on a simple environment
+      that rewards synchronized actions.
+    - Stops the training after 100k steps or when the mean episode return
+      exceeds 150 in evaluation, i.e. when the agent has learned to
+      synchronize its actions.
+
+How to run this script
+----------------------
+`python [script file name].py --enable-new-api-stack --num-env-runners 2`
+
+Control the number of `EnvRunner`s with the `--num-env-runners` flag. More
+`EnvRunner`s increase the sampling speed.
+
+For debugging, use the following additional command line options
+`--no-tune --num-env-runners=0`
+which should allow you to set breakpoints anywhere in the RLlib code and
+have the execution stop there for inspection and debugging.
+
+For logging to your WandB account, use:
+`--wandb-key=[your WandB API key] --wandb-project=[some project name]
+--wandb-run-name=[optional: WandB run name (within the defined project)]`
+
+Results to expect
+-----------------
+You should expect a return of around 155-160, achieved by a simple PPO policy
+(no tuning, just RLlib's default settings) after ~36,000 timesteps sampled
+(trained). For details, also take a closer look at the `CorrelatedActionsEnv`
+environment. Rewards are such that to receive a return over 100, the agent
+must learn to synchronize its actions.
+"""
+
+
+from ray.rllib.algorithms.ppo import PPOConfig
+from ray.rllib.core.models.catalog import Catalog
+from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
+from ray.rllib.examples.envs.classes.correlated_actions_env import CorrelatedActionsEnv
+from ray.rllib.examples.rl_modules.classes.autoregressive_actions_rlm import (
+    AutoregressiveActionTorchRLM,
+)
+from ray.rllib.utils.metrics import (
+    ENV_RUNNER_RESULTS,
+    EPISODE_RETURN_MEAN,
+    EVALUATION_RESULTS,
+    NUM_ENV_STEPS_SAMPLED_LIFETIME,
+)
+from ray.rllib.utils.test_utils import (
+    add_rllib_example_script_args,
+    run_rllib_example_script_experiment,
+)
+from ray.tune import register_env
+
+
+register_env("correlated_actions_env", lambda _: CorrelatedActionsEnv(_))
+
+parser = add_rllib_example_script_args(
+    default_iters=200,
+    default_timesteps=100000,
+    default_reward=150.0,
+)
+
+if __name__ == "__main__":
+    args = parser.parse_args()
+
+    if args.algo != "PPO":
+        raise ValueError("This example only supports PPO. Please use --algo=PPO.")
+
+    base_config = (
+        PPOConfig()
+        .environment(env="correlated_actions_env")
+        .rl_module(
+            model_config_dict={
+                "post_fcnet_hiddens": [64, 64],
+                "post_fcnet_activation": "relu",
+            },
+            # We need to explicitly specify here the RLModule to use and
+            # the catalog needed to build it.
+            rl_module_spec=SingleAgentRLModuleSpec(
+                module_class=AutoregressiveActionTorchRLM,
+                catalog_class=Catalog,
+            ),
+        )
+        .evaluation(
+            evaluation_num_env_runners=1,
+            evaluation_interval=1,
+            # Run evaluation parallel to training to speed up the example.
+            evaluation_parallel_to_training=True,
+        )
+    )
+
+    # Let's stop the training after 100k steps or when the mean episode return
+    # exceeds 150 in evaluation.
+    stop = {
+        f"{NUM_ENV_STEPS_SAMPLED_LIFETIME}": 100000,
+        f"{EVALUATION_RESULTS}/{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}": 150.0,
+    }
+
+    # Run the example (with Tune).
+    run_rllib_example_script_experiment(base_config, args, stop=stop)
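For reference, the `rl_module_spec`/`catalog_class` wiring used above can also be exercised stand-alone. This is a rough sketch, assuming `SingleAgentRLModuleSpec.build()` is available in this RLlib version; the spaces and model config simply mirror `CorrelatedActionsEnv` and the training config above:

import gymnasium as gym
import numpy as np

from ray.rllib.core.models.catalog import Catalog
from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
from ray.rllib.examples.rl_modules.classes.autoregressive_actions_rlm import (
    AutoregressiveActionTorchRLM,
)

spec = SingleAgentRLModuleSpec(
    module_class=AutoregressiveActionTorchRLM,
    observation_space=gym.spaces.Box(0, 1, shape=(1,), dtype=np.float32),
    action_space=gym.spaces.Tuple([gym.spaces.Discrete(2), gym.spaces.Discrete(2)]),
    model_config_dict={
        "post_fcnet_hiddens": [64, 64],
        "post_fcnet_activation": "relu",
    },
    catalog_class=Catalog,
)
# Instantiates the encoder, the prior/posterior heads, and the value head.
rl_module = spec.build()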
diff --git a/rllib/examples/rl_modules/classes/autoregressive_actions_rlm.py b/rllib/examples/rl_modules/classes/autoregressive_actions_rlm.py
new file mode 100644
index 000000000000..fa773944976b
--- /dev/null
+++ b/rllib/examples/rl_modules/classes/autoregressive_actions_rlm.py
@@ -0,0 +1,295 @@
+from abc import abstractmethod
+from typing import Any, Dict, Type
+
+from ray.rllib.core import Columns
+from ray.rllib.core.models.base import ENCODER_OUT
+from ray.rllib.core.models.configs import MLPHeadConfig
+from ray.rllib.core.models.specs.specs_dict import SpecDict
+from ray.rllib.core.rl_module.rl_module import RLModule
+from ray.rllib.core.rl_module.torch.torch_rl_module import TorchRLModule
+from ray.rllib.models.distributions import Distribution
+from ray.rllib.utils.annotations import (
+    override,
+    OverrideToImplementCustomLogic_CallToSuperRecommended,
+)
+from ray.rllib.utils.framework import try_import_torch
+from ray.rllib.utils.torch_utils import convert_to_torch_tensor
+from ray.rllib.utils.typing import TensorType
+
+torch, nn = try_import_torch()
+
+
+# TODO (simon): Improvements: `inference-only` mode.
+class AutoregressiveActionRLM(RLModule):
+    """An RLModule that implements an autoregressive action distribution.
+
+    This RLModule implements an autoregressive action distribution, where the
+    action is sampled in two steps. First, the prior action is sampled from a
+    prior distribution. Then, the posterior action is sampled from a posterior
+    distribution that depends on the prior action and the input data. The prior
+    and posterior distributions are implemented as MLPs.
+
+    The following components are implemented:
+    - ENCODER: An encoder that processes the observations from the environment.
+    - PI: A policy head that outputs the actions, the log probabilities of the
+        actions, and the input to the action distribution. This head is composed
+        of two sub-heads:
+        - A prior head that outputs the logits for the prior action distribution.
+        - A posterior head that outputs the logits for the posterior action
+            distribution.
+    - A value function head that outputs the value function.
+
+    Note, this RLModule is implemented for the `PPO` algorithm only. It is not
+    guaranteed to work with other algorithms.
+    """
+
+    @override(RLModule)
+    @OverrideToImplementCustomLogic_CallToSuperRecommended
+    def setup(self):
+        super().setup()
+
+        # Build the encoder.
+        self.encoder = self.config.get_catalog().build_encoder(framework=self.framework)
+
+        # Build the prior and posterior heads.
+        # Note, the action space is a Tuple space.
+        self.action_dist_cls = self.config.get_catalog().get_action_dist_cls(
+            self.framework
+        )
+        # Note further, we need to know the required input dimensions for
+        # the partial distributions.
+        self.required_output_dims = self.action_dist_cls.required_input_dim(
+            space=self.config.action_space,
+            as_list=True,
+        )
+        action_dims = self.config.action_space[0].shape or (1,)
+        latent_dims = self.config.get_catalog().latent_dims
+        # Build the prior head.
+        prior_config = MLPHeadConfig(
+            # Use the hidden dimension from the encoder output.
+            input_dims=latent_dims,
+            # Use configurations from the `model_config_dict`.
+            hidden_layer_dims=self.config.model_config_dict["post_fcnet_hiddens"],
+            hidden_layer_activation=self.config.model_config_dict[
+                "post_fcnet_activation"
+            ],
+            output_layer_dim=self.required_output_dims[0],
+            output_layer_activation="linear",
+        )
+        # Build the posterior head.
+        posterior_config = MLPHeadConfig(
+            input_dims=(latent_dims[0] + action_dims[0],),
+            hidden_layer_dims=self.config.model_config_dict["post_fcnet_hiddens"],
+            hidden_layer_activation=self.config.model_config_dict[
+                "post_fcnet_activation"
+            ],
+            output_layer_dim=self.required_output_dims[1],
+            output_layer_activation="linear",
+        )
+
+        self.prior = prior_config.build(framework=self.framework)
+        self.posterior = posterior_config.build(framework=self.framework)
+
+        # Build the value function head.
+        vf_config = MLPHeadConfig(
+            input_dims=latent_dims,
+            hidden_layer_dims=self.config.model_config_dict["post_fcnet_hiddens"],
+            hidden_layer_activation=self.config.model_config_dict[
+                "post_fcnet_activation"
+            ],
+            output_layer_dim=1,
+            output_layer_activation="linear",
+        )
+        self.vf = vf_config.build(framework=self.framework)
+
+    @override(RLModule)
+    def get_train_action_dist_cls(self) -> Type[Distribution]:
+        return self.action_dist_cls
+
+    @override(RLModule)
+    def get_exploration_action_dist_cls(self) -> Type[Distribution]:
+        return self.action_dist_cls
+
+    @override(RLModule)
+    def get_inference_action_dist_cls(self) -> Type[Distribution]:
+        return self.action_dist_cls
+
+    @override(RLModule)
+    def output_specs_inference(self) -> SpecDict:
+        return [Columns.ACTIONS]
+
+    @override(RLModule)
+    def output_specs_exploration(self) -> SpecDict:
+        return [Columns.ACTION_DIST_INPUTS, Columns.ACTIONS, Columns.ACTION_LOGP]
+
+    @override(RLModule)
+    def output_specs_train(self) -> SpecDict:
+        return [
+            Columns.ACTION_DIST_INPUTS,
+            Columns.ACTIONS,
+            Columns.ACTION_LOGP,
+            Columns.VF_PREDS,
+        ]
+
+    @abstractmethod
+    def pi(self, batch: Dict[str, TensorType]) -> Dict[str, TensorType]:
+        """Computes the policy outputs given a batch of data.
+
+        Args:
+            batch: The input batch to pass through the policy head.
+
+        Returns:
+            A dict mapping Column names to batches of policy outputs.
+        """
+
+    @abstractmethod
+    def _compute_values(self, batch) -> Any:
+        """Computes values using the vf-specific network(s) and given a batch of data.
+
+        Args:
+            batch: The input batch to pass through this RLModule (value function
+                encoder and vf-head).
+
+        Returns:
+            A dict mapping ModuleIDs to batches of value function outputs (already
+            squeezed on the last dimension, which should have shape (1,) b/c of the
+            single value output node). However, for complex multi-agent settings with
+            shared value networks, the output might look different (e.g. a single
+            return batch without the ModuleID-based mapping).
+        """
+
+
+class AutoregressiveActionTorchRLM(TorchRLModule, AutoregressiveActionRLM):
+    @override(AutoregressiveActionRLM)
+    def pi(
+        self, batch: Dict[str, TensorType], inference: bool = False
+    ) -> Dict[str, TensorType]:
+        pi_outs = {}
+
+        # Prior forward pass.
+        prior_out = self.prior(batch)
+        prior_logits = torch.cat(
+            [
+                prior_out,
+                # We add zeros for the posterior logits, which we do not have at
+                # this point in time.
+                torch.zeros(size=(prior_out.shape[0], self.required_output_dims[1])),
+            ],
+            dim=-1,
+        )
+        # Get the prior action distribution to sample the prior action.
+        if inference:
+            # In inference mode, we need the distribution to be deterministic.
+            prior_action_dist = self.action_dist_cls.from_logits(
+                prior_logits
+            ).to_deterministic()
+            # In inference mode, we can sample in a simple way.
+            prior_action = prior_action_dist._flat_child_distributions[0].sample()
+        else:
+            prior_action_dist = self.action_dist_cls.from_logits(prior_logits)
+            # Note, `TorchMultiDistribution.from_logits` does set the `logits`, but not
+            # the `probs` attribute. We need to set the `probs` attribute to be able to
+            # sample from the distribution in a differentiable way.
+            prior_action_dist._flat_child_distributions[0].probs = torch.softmax(
+                prior_out, dim=-1
+            )
+            prior_action_dist._flat_child_distributions[0].logits = None
+            # Otherwise, we need to be able to backpropagate through the prior action,
+            # which is why we sample from the distribution using the `rsample` method.
+            # TODO (simon, sven): Check, if we need to return the one-hot sampled
+            # action instead of the real-valued one.
+            prior_action = torch.argmax(
+                prior_action_dist._flat_child_distributions[0].rsample(),
+                dim=-1,
+            )
+
+        # Posterior forward pass.
+        posterior_batch = torch.cat([batch, prior_action.view(-1, 1)], dim=-1)
+        posterior_out = self.posterior(posterior_batch)
+        # Concatenate the prior and posterior logits to get the final logits.
+        posterior_logits = torch.cat([prior_out, posterior_out], dim=-1)
+        if inference:
+            posterior_action_dist = self.action_dist_cls.from_logits(
+                posterior_logits
+            ).to_deterministic()
+            # Sample the posterior action.
+            posterior_action = posterior_action_dist._flat_child_distributions[
+                1
+            ].sample()
+
+        else:
+            # Get the posterior action distribution to sample the posterior action.
+            posterior_action_dist = self.action_dist_cls.from_logits(posterior_logits)
+            # Sample the posterior action.
+            posterior_action = posterior_action_dist._flat_child_distributions[
+                1
+            ].sample()
+
+        # We need the log probabilities of the sampled actions for the loss
+        # calculation.
+        prior_action_logp = prior_action_dist._flat_child_distributions[0].logp(
+            prior_action
+        )
+        posterior_action_logp = posterior_action_dist._flat_child_distributions[
+            1
+        ].logp(posterior_action)
+        pi_outs[Columns.ACTION_LOGP] = prior_action_logp + posterior_action_logp
+        # We also need the input to the action distribution to calculate the
+        # KL-divergence.
+        pi_outs[Columns.ACTION_DIST_INPUTS] = posterior_logits
+
+        # Concatenate the prior and posterior actions and log probabilities.
+        pi_outs[Columns.ACTIONS] = (prior_action, posterior_action)
+
+        return pi_outs
+
+    @override(TorchRLModule)
+    def _forward_inference(self, batch: Dict[str, TensorType]) -> Dict[str, TensorType]:
+
+        # Encoder forward pass.
+        encoder_out = self.encoder(batch)
+
+        # Policy head forward pass.
+        return self.pi(encoder_out[ENCODER_OUT], inference=True)
+
+    @override(TorchRLModule)
+    def _forward_exploration(
+        self, batch: Dict[str, TensorType], **kwargs
+    ) -> Dict[str, TensorType]:
+        # Encoder forward pass.
+        encoder_out = self.encoder(batch)
+
+        # Policy head forward pass.
+        return self.pi(encoder_out[ENCODER_OUT], inference=False)
+
+    @override(TorchRLModule)
+    def _forward_train(self, batch: Dict[str, TensorType]) -> Dict[str, TensorType]:
+
+        outs = {}
+
+        # Encoder forward pass.
+        encoder_out = self.encoder(batch)
+
+        # Policy head forward pass.
+        outs.update(self.pi(encoder_out[ENCODER_OUT]))
+
+        # Value function head forward pass.
+        vf_out = self.vf(encoder_out[ENCODER_OUT])
+        outs[Columns.VF_PREDS] = vf_out.squeeze(-1)
+
+        return outs
+
+    @override(AutoregressiveActionRLM)
+    def _compute_values(self, batch, device=None):
+        infos = batch.pop(Columns.INFOS, None)
+        batch = convert_to_torch_tensor(batch, device=device)
+        if infos is not None:
+            batch[Columns.INFOS] = infos
+
+        # Encoder forward pass.
+        encoder_outs = self.encoder(batch)[ENCODER_OUT]
+
+        # Value head forward pass.
+        vf_out = self.vf(encoder_outs)
+
+        # Squeeze out last dimension (single node value head).
+        return vf_out.squeeze(-1)
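To make the data flow of the module above concrete (encoder -> prior head -> sampled a1 -> posterior head conditioned on a1 -> sampled a2), here is a hypothetical check of the exploration outputs. It reuses the `rl_module` built in the spec sketch further up, assumes the default MLP encoder reads `Columns.OBS`, and calls the private `_forward_exploration` directly to skip RLlib's spec checks for brevity:

import torch

from ray.rllib.core import Columns

batch = {Columns.OBS: torch.rand(4, 1)}  # Four dummy observations in [0, 1).
out = rl_module._forward_exploration(batch)

a1, a2 = out[Columns.ACTIONS]             # Sampled prior (a1) and posterior (a2) actions.
logp = out[Columns.ACTION_LOGP]           # Joint log-prob: logp(a1) + logp(a2 | a1).
logits = out[Columns.ACTION_DIST_INPUTS]  # Concatenated prior + posterior logits, shape (4, 4).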
diff --git a/rllib/models/torch/torch_distributions.py b/rllib/models/torch/torch_distributions.py
index d9f42c0ec473..d5f3bccc220c 100644
--- a/rllib/models/torch/torch_distributions.py
+++ b/rllib/models/torch/torch_distributions.py
@@ -600,8 +600,13 @@ def sample(self):
 
     @staticmethod
     @override(Distribution)
-    def required_input_dim(space: gym.Space, input_lens: List[int], **kwargs) -> int:
-        return sum(input_lens)
+    def required_input_dim(
+        space: gym.Space, input_lens: List[int], as_list: bool = False, **kwargs
+    ) -> int:
+        if as_list:
+            return input_lens
+        else:
+            return sum(input_lens)
 
     @classmethod
     @override(Distribution)
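The new `as_list` switch is what lets the RLModule above size its prior and posterior heads separately: with `as_list=True`, the per-child input dimensions come back unsummed instead of as a single total. A small illustration of the intended behavior, calling the static method directly with an explicit `input_lens` (normally bound by the partial distribution class that the `Catalog` returns):

import gymnasium as gym

from ray.rllib.models.torch.torch_distributions import TorchMultiDistribution

space = gym.spaces.Tuple([gym.spaces.Discrete(2), gym.spaces.Discrete(2)])

# Summed dimension, as before this patch: 2 + 2 = 4.
assert TorchMultiDistribution.required_input_dim(space, input_lens=[2, 2]) == 4
# New per-child view used to size the prior and posterior heads: [2, 2].
assert TorchMultiDistribution.required_input_dim(
    space, input_lens=[2, 2], as_list=True
) == [2, 2]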