Skip to content

Commit

Permalink
Defined a new autoregressive actions environment in which the agent c…
Browse files Browse the repository at this point in the history
…annot only watch the state but needs to also watch the first action. Furthermore, implemented the 'ValueFunctionAPI' in the 'AutoregressiveActionsRLM' and ran some tests.

Signed-off-by: simonsays1980 <[email protected]>
  • Loading branch information
simonsays1980 committed Sep 27, 2024
1 parent 80c2a42 commit 330594e
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 55 deletions.
77 changes: 51 additions & 26 deletions rllib/examples/envs/classes/correlated_actions_env.py
Original file line number Diff line number Diff line change
@@ -1,44 +1,69 @@
import gymnasium as gym
from gymnasium.spaces import Box, Discrete, Tuple
import numpy as np
import random


class CorrelatedActionsEnv(gym.Env):
"""
Simple env in which the policy has to emit a tuple of equal actions.
class AutoRegressiveActionEnv(gym.Env):
"""Custom Environment with autoregressive continuous actions.
Simple env in which the policy has to emit a tuple of correlated actions.
In each step, the agent observes a random number (0 or 1) and has to choose
two actions a1 and a2.
It gets +5 reward for matching a1 to the random obs and +5 for matching a2
to a1. I.e., +10 at most per step.
In each step, the agent observes a random number (between -1 and 1) and has
to choose two actions a1 and a2.
It gets 0 reward for matching a2 to the random obs times action a1. In all
other cases the negative deviance between the desired action a2 and its
actual counterpart serves as reward.
One way to effectively learn this is through correlated action
distributions, e.g., in examples/rl_modules/autoregressive_action_rlm.py
There are 20 steps. Hence, the best score would be ~200 reward.
The game ends after the first step.
"""

def __init__(self, _=None):
self.observation_space = Box(0, 1, shape=(1,), dtype=np.float32)

# Define the action space (two continuous actions a1, a2)
self.action_space = Tuple([Discrete(2), Discrete(2)])
self.last_observation = None

def reset(self, *, seed=None, options=None):
self.t = 0
self.last_observation = np.array([random.choice([0, 1])], dtype=np.float32)
return self.last_observation, {}
# Define the observation space (state is a single continuous value)
self.observation_space = Box(low=-1, high=1, shape=(1,), dtype=np.float32)

# Internal state for the environment (e.g., could represent a factor
# influencing the relationship)
self.state = None

def reset(self, seed=None):
"""Reset the environment to an initial state."""
super().reset(seed=seed)

# Randomly initialize the state between -1 and 1
self.state = np.random.uniform(-1, 1, size=(1,))

return self.state, {}

def step(self, action):
self.t += 1
"""Apply the autoregressive action and return step information."""

# Extract actions
a1, a2 = action
reward = 0
# Encourage correlation between most recent observation and a1.
if a1 == self.last_observation:
reward += 5
# Encourage correlation between a1 and a2.
if a1 == a2:
reward += 5
done = truncated = self.t > 20
self.last_observation = np.array([random.choice([0, 1])], dtype=np.float32)
return self.last_observation, reward, done, truncated, {}

# The state determines the desired relationship between a1 and a2
desired_a2 = (
self.state[0] * a1
) # Autoregressive relationship dependent on state

# Reward is based on how close a2 is to the state-dependent autoregressive
# relationship
reward = -np.abs(a2 - desired_a2) # Negative absolute error as the reward

# Optionally: add some noise or complexity to the reward function
# reward += np.random.normal(0, 0.01) # Small noise can be added

# Terminate after each step (no episode length in this simple example)
done = True

# Empty info dictionary
info = {}

return self.state, reward, done, False, info
10 changes: 6 additions & 4 deletions rllib/examples/rl_modules/autoregressive_actions_rl_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
- Uses this `RLModule` in a PPO training run on a simple environment
that rewards synchronized actions.
- Stops the training after 100k steps or when the mean episode return
exceeds 150 in evaluation, i.e. if the agent has learned to
exceeds 0.01 in evaluation, i.e. if the agent has learned to
synchronize its actions.
How to run this script
Expand Down Expand Up @@ -42,7 +42,9 @@
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.core.models.catalog import Catalog
from ray.rllib.core.rl_module.rl_module import RLModuleSpec
from ray.rllib.examples.envs.classes.correlated_actions_env import CorrelatedActionsEnv
from ray.rllib.examples.envs.classes.correlated_actions_env import (
AutoRegressiveActionEnv,
)
from ray.rllib.examples.rl_modules.classes.autoregressive_actions_rlm import (
AutoregressiveActionsTorchRLM,
)
Expand All @@ -59,7 +61,7 @@
from ray.tune import register_env


register_env("correlated_actions_env", lambda _: CorrelatedActionsEnv(_))
register_env("correlated_actions_env", lambda _: AutoRegressiveActionEnv(_))

parser = add_rllib_example_script_args(
default_iters=200,
Expand Down Expand Up @@ -100,7 +102,7 @@
# exceeds 150 in evaluation.
stop = {
f"{NUM_ENV_STEPS_SAMPLED_LIFETIME}": 100000,
f"{EVALUATION_RESULTS}/{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}": 150.0,
f"{EVALUATION_RESULTS}/{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}": 0.012,
}

# Run the example (with Tune).
Expand Down
31 changes: 6 additions & 25 deletions rllib/examples/rl_modules/classes/autoregressive_actions_rlm.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,26 @@
import abc
from abc import abstractmethod
from typing import Any, Dict
from typing import Dict

from ray.rllib.core import Columns
from ray.rllib.core.models.base import ENCODER_OUT
from ray.rllib.core.models.configs import MLPHeadConfig
from ray.rllib.core.models.specs.specs_dict import SpecDict
from ray.rllib.core.rl_module.apis.value_function_api import ValueFunctionAPI
from ray.rllib.core.rl_module.rl_module import RLModule
from ray.rllib.core.rl_module.torch.torch_rl_module import TorchRLModule
from ray.rllib.utils.annotations import (
override,
OverrideToImplementCustomLogic_CallToSuperRecommended,
)
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.torch_utils import convert_to_torch_tensor
from ray.rllib.utils.typing import TensorType

torch, nn = try_import_torch()


# TODO (simon): Improvements: `inference-only` mode.
class AutoregressiveActionsRLM(RLModule):
class AutoregressiveActionsRLM(RLModule, ValueFunctionAPI, abc.ABC):
"""An RLModule that implements an autoregressive action distribution.
This RLModule implements an autoregressive action distribution, where the
Expand Down Expand Up @@ -124,22 +125,6 @@ def pi(self, batch: Dict[str, TensorType]) -> Dict[str, TensorType]:
A dict mapping Column names to batches of policy outputs.
"""

@abstractmethod
def _compute_values(self, batch) -> Any:
"""Computes values using the vf-specific network(s) and given a batch of data.
Args:
batch: The input batch to pass through this RLModule (value function
encoder and vf-head).
Returns:
A dict mapping ModuleIDs to batches of value function outputs (already
squeezed on the last dimension (which should have shape (1,) b/c of the
single value output node). However, for complex multi-agent settings with
shareed value networks, the output might look differently (e.g. a single
return batch without the ModuleID-based mapping).
"""


class AutoregressiveActionsTorchRLM(TorchRLModule, AutoregressiveActionsRLM):
@override(AutoregressiveActionsRLM)
Expand Down Expand Up @@ -261,12 +246,8 @@ def _forward_train(self, batch: Dict[str, TensorType]) -> Dict[str, TensorType]:

return outs

@override(AutoregressiveActionsRLM)
def _compute_values(self, batch, device=None):
infos = batch.pop(Columns.INFOS, None)
batch = convert_to_torch_tensor(batch, device=device)
if infos is not None:
batch[Columns.INFOS] = infos
@override(ValueFunctionAPI)
def compute_values(self, batch: Dict[str, TensorType]):

# Encoder forward pass.
encoder_outs = self.encoder(batch)[ENCODER_OUT]
Expand Down

0 comments on commit 330594e

Please sign in to comment.