diff --git a/rllib/examples/envs/classes/correlated_actions_env.py b/rllib/examples/envs/classes/correlated_actions_env.py
index a3db0d7556d9..fb0ab84ae458 100644
--- a/rllib/examples/envs/classes/correlated_actions_env.py
+++ b/rllib/examples/envs/classes/correlated_actions_env.py
@@ -1,44 +1,69 @@
 import gymnasium as gym
 from gymnasium.spaces import Box, Discrete, Tuple
 import numpy as np
-import random
 
 
-class CorrelatedActionsEnv(gym.Env):
-    """
-    Simple env in which the policy has to emit a tuple of equal actions.
+class AutoRegressiveActionEnv(gym.Env):
+    """Custom environment with autoregressive actions.
+
+    Simple env in which the policy has to emit a tuple of correlated actions.
 
-    In each step, the agent observes a random number (0 or 1) and has to choose
-    two actions a1 and a2.
-    It gets +5 reward for matching a1 to the random obs and +5 for matching a2
-    to a1. I.e., +10 at most per step.
+    In each step, the agent observes a random number (between -1 and 1) and has
+    to choose two actions a1 and a2.
+
+    It gets 0 reward for matching a2 to the random obs times action a1. In all
+    other cases the negative absolute difference between the desired a2 and the
+    actual a2 serves as the reward.
 
     One way to effectively learn this is through correlated action
     distributions, e.g., in examples/rl_modules/autoregressive_action_rlm.py
 
-    There are 20 steps. Hence, the best score would be ~200 reward.
+    The game ends after the first step.
     """
 
     def __init__(self, _=None):
-        self.observation_space = Box(0, 1, shape=(1,), dtype=np.float32)
+
+        # Define the action space (two discrete actions a1, a2)
         self.action_space = Tuple([Discrete(2), Discrete(2)])
-        self.last_observation = None
-
-    def reset(self, *, seed=None, options=None):
-        self.t = 0
-        self.last_observation = np.array([random.choice([0, 1])], dtype=np.float32)
-        return self.last_observation, {}
+
+        # Define the observation space (state is a single continuous value)
+        self.observation_space = Box(low=-1, high=1, shape=(1,), dtype=np.float32)
+
+        # Internal state for the environment (e.g., could represent a factor
+        # influencing the relationship)
+        self.state = None
+
+    def reset(self, seed=None, options=None):
+        """Reset the environment to an initial state."""
+        super().reset(seed=seed)
+
+        # Randomly initialize the state between -1 and 1 (using the seeded RNG)
+        self.state = self.np_random.uniform(-1, 1, size=(1,)).astype(np.float32)
+
+        return self.state, {}
 
     def step(self, action):
-        self.t += 1
+        """Apply the autoregressive action and return step information."""
+
+        # Extract actions
         a1, a2 = action
-        reward = 0
-        # Encourage correlation between most recent observation and a1.
-        if a1 == self.last_observation:
-            reward += 5
-        # Encourage correlation between a1 and a2.
-        if a1 == a2:
-            reward += 5
-        done = truncated = self.t > 20
-        self.last_observation = np.array([random.choice([0, 1])], dtype=np.float32)
-        return self.last_observation, reward, done, truncated, {}
+
+        # The state determines the desired relationship between a1 and a2
+        desired_a2 = (
+            self.state[0] * a1
+        )  # Autoregressive relationship dependent on state
+
+        # Reward is based on how close a2 is to the state-dependent autoregressive
+        # relationship
+        reward = -np.abs(a2 - desired_a2)  # Negative absolute error as the reward
+
+        # Optionally: add some noise or complexity to the reward function
+        # reward += np.random.normal(0, 0.01)  # Small noise can be added
+
+        # Terminate after each step (no episode length in this simple example)
+        done = True
+
+        # Empty info dictionary
+        info = {}
+
+        return self.state, reward, done, False, info
diff --git a/rllib/examples/rl_modules/autoregressive_actions_rl_module.py b/rllib/examples/rl_modules/autoregressive_actions_rl_module.py
index 0f0c8510f9ad..914ed3b364ba 100644
--- a/rllib/examples/rl_modules/autoregressive_actions_rl_module.py
+++ b/rllib/examples/rl_modules/autoregressive_actions_rl_module.py
@@ -10,7 +10,7 @@
 - Uses this `RLModule` in a PPO training run on a simple environment that
   rewards synchronized actions.
 - Stops the training after 100k steps or when the mean episode return
-  exceeds 150 in evaluation, i.e. if the agent has learned to
+  exceeds 0.012 in evaluation, i.e. if the agent has learned to
   synchronize its actions.
 
 How to run this script
@@ -42,7 +42,9 @@
 from ray.rllib.algorithms.ppo import PPOConfig
 from ray.rllib.core.models.catalog import Catalog
 from ray.rllib.core.rl_module.rl_module import RLModuleSpec
-from ray.rllib.examples.envs.classes.correlated_actions_env import CorrelatedActionsEnv
+from ray.rllib.examples.envs.classes.correlated_actions_env import (
+    AutoRegressiveActionEnv,
+)
 from ray.rllib.examples.rl_modules.classes.autoregressive_actions_rlm import (
     AutoregressiveActionsTorchRLM,
 )
@@ -59,7 +61,7 @@
 from ray.tune import register_env
 
 
-register_env("correlated_actions_env", lambda _: CorrelatedActionsEnv(_))
+register_env("correlated_actions_env", lambda _: AutoRegressiveActionEnv(_))
 
 parser = add_rllib_example_script_args(
     default_iters=200,
@@ -100,7 +102,7 @@
     # exceeds 150 in evaluation.
     stop = {
         f"{NUM_ENV_STEPS_SAMPLED_LIFETIME}": 100000,
-        f"{EVALUATION_RESULTS}/{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}": 150.0,
+        f"{EVALUATION_RESULTS}/{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}": 0.012,
     }
 
     # Run the example (with Tune).
diff --git a/rllib/examples/rl_modules/classes/autoregressive_actions_rlm.py b/rllib/examples/rl_modules/classes/autoregressive_actions_rlm.py
index d6941aadc10c..d0ff7650a166 100644
--- a/rllib/examples/rl_modules/classes/autoregressive_actions_rlm.py
+++ b/rllib/examples/rl_modules/classes/autoregressive_actions_rlm.py
@@ -1,10 +1,12 @@
+import abc
 from abc import abstractmethod
-from typing import Any, Dict
+from typing import Dict
 
 from ray.rllib.core import Columns
 from ray.rllib.core.models.base import ENCODER_OUT
 from ray.rllib.core.models.configs import MLPHeadConfig
 from ray.rllib.core.models.specs.specs_dict import SpecDict
+from ray.rllib.core.rl_module.apis.value_function_api import ValueFunctionAPI
 from ray.rllib.core.rl_module.rl_module import RLModule
 from ray.rllib.core.rl_module.torch.torch_rl_module import TorchRLModule
 from ray.rllib.utils.annotations import (
@@ -12,14 +14,13 @@
     OverrideToImplementCustomLogic_CallToSuperRecommended,
 )
 from ray.rllib.utils.framework import try_import_torch
-from ray.rllib.utils.torch_utils import convert_to_torch_tensor
 from ray.rllib.utils.typing import TensorType
 
 torch, nn = try_import_torch()
 
 
 # TODO (simon): Improvements: `inference-only` mode.
-class AutoregressiveActionsRLM(RLModule):
+class AutoregressiveActionsRLM(RLModule, ValueFunctionAPI, abc.ABC):
     """An RLModule that implements an autoregressive action distribution.
 
     This RLModule implements an autoregressive action distribution, where the
@@ -124,22 +125,6 @@ def pi(self, batch: Dict[str, TensorType]) -> Dict[str, TensorType]:
             A dict mapping Column names to batches of policy outputs.
         """
 
-    @abstractmethod
-    def _compute_values(self, batch) -> Any:
-        """Computes values using the vf-specific network(s) and given a batch of data.
-
-        Args:
-            batch: The input batch to pass through this RLModule (value function
-                encoder and vf-head).
-
-        Returns:
-            A dict mapping ModuleIDs to batches of value function outputs (already
-            squeezed on the last dimension (which should have shape (1,) b/c of the
-            single value output node). However, for complex multi-agent settings with
-            shareed value networks, the output might look differently (e.g. a single
-            return batch without the ModuleID-based mapping).
-        """
-
 
 class AutoregressiveActionsTorchRLM(TorchRLModule, AutoregressiveActionsRLM):
     @override(AutoregressiveActionsRLM)
@@ -261,12 +246,8 @@ def _forward_train(self, batch: Dict[str, TensorType]) -> Dict[str, TensorType]:
 
         return outs
 
-    @override(AutoregressiveActionsRLM)
-    def _compute_values(self, batch, device=None):
-        infos = batch.pop(Columns.INFOS, None)
-        batch = convert_to_torch_tensor(batch, device=device)
-        if infos is not None:
-            batch[Columns.INFOS] = infos
+    @override(ValueFunctionAPI)
+    def compute_values(self, batch: Dict[str, TensorType]):
         # Encoder forward pass.
         encoder_outs = self.encoder(batch)[ENCODER_OUT]
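Usage sketch (not part of the patch): a minimal way to exercise the renamed
AutoRegressiveActionEnv once the diff above is applied, illustrating the
single-step, state-dependent reward described in its docstring. The import
path is the one shown in the diff; the chosen actions and seed are
illustrative only.

    from ray.rllib.examples.envs.classes.correlated_actions_env import (
        AutoRegressiveActionEnv,
    )

    env = AutoRegressiveActionEnv()
    obs, info = env.reset(seed=0)  # obs is a single float in [-1, 1]

    # Both sub-actions are Discrete(2); the target for a2 is obs[0] * a1.
    a1, a2 = 1, 0
    obs, reward, terminated, truncated, info = env.step((a1, a2))

    # reward == -abs(a2 - obs[0] * a1), and the episode ends after this
    # single step, so the best achievable episode return is 0.
    print(reward, terminated)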