forked from ray-project/ray
Commit
[RLlib] Update autoregressive actions example. (ray-project#47829)
Signed-off-by: ujjawal-khare <[email protected]>
1 parent e966649, commit 1642d61
Showing 5 changed files with 78 additions and 61 deletions.
@@ -1,44 +1,74 @@
 import gymnasium as gym
 from gymnasium.spaces import Box, Discrete, Tuple
 import numpy as np
-import random
+from typing import Any, Dict, Optional


-class CorrelatedActionsEnv(gym.Env):
-    """
-    Simple env in which the policy has to emit a tuple of equal actions.
+class AutoRegressiveActionEnv(gym.Env):
+    """Custom Environment with autoregressive continuous actions.
+
+    Simple env in which the policy has to emit a tuple of correlated actions.

-    In each step, the agent observes a random number (0 or 1) and has to choose
-    two actions a1 and a2.
-    It gets +5 reward for matching a1 to the random obs and +5 for matching a2
-    to a1. I.e., +10 at most per step.
+    In each step, the agent observes a random number (between -1 and 1) and has
+    to choose two actions a1 and a2.
+
+    It gets 0 reward for matching a2 to the random obs times action a1. In all
+    other cases the negative deviance between the desired action a2 and its
+    actual counterpart serves as reward. The reward is constructed in such a
+    way that actions need to be correlated to succeed. It is not possible
+    for the network to learn each action head separately.

     One way to effectively learn this is through correlated action
     distributions, e.g., in examples/rl_modules/autoregressive_action_rlm.py

-    There are 20 steps. Hence, the best score would be ~200 reward.
+    The game ends after the first step.
     """

     def __init__(self, _=None):
-        self.observation_space = Box(0, 1, shape=(1,), dtype=np.float32)
+
+        # Define the action space (two continuous actions a1, a2)
         self.action_space = Tuple([Discrete(2), Discrete(2)])
-        self.last_observation = None

-    def reset(self, *, seed=None, options=None):
-        self.t = 0
-        self.last_observation = np.array([random.choice([0, 1])], dtype=np.float32)
-        return self.last_observation, {}
+        # Define the observation space (state is a single continuous value)
+        self.observation_space = Box(low=-1, high=1, shape=(1,), dtype=np.float32)
+
+        # Internal state for the environment (e.g., could represent a factor
+        # influencing the relationship)
+        self.state = None
+
+    def reset(
+        self, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None
+    ):
+        """Reset the environment to an initial state."""
+        super().reset(seed=seed, options=options)
+
+        # Randomly initialize the state between -1 and 1
+        self.state = np.random.uniform(-1, 1, size=(1,))
+
+        return self.state, {}

     def step(self, action):
-        self.t += 1
+        """Apply the autoregressive action and return step information."""
+
+        # Extract actions
         a1, a2 = action
-        reward = 0
-        # Encourage correlation between most recent observation and a1.
-        if a1 == self.last_observation:
-            reward += 5
-        # Encourage correlation between a1 and a2.
-        if a1 == a2:
-            reward += 5
-        done = truncated = self.t > 20
-        self.last_observation = np.array([random.choice([0, 1])], dtype=np.float32)
-        return self.last_observation, reward, done, truncated, {}
+
+        # The state determines the desired relationship between a1 and a2
+        desired_a2 = (
+            self.state[0] * a1
+        )  # Autoregressive relationship dependent on state
+
+        # Reward is based on how close a2 is to the state-dependent autoregressive
+        # relationship
+        reward = -np.abs(a2 - desired_a2)  # Negative absolute error as the reward
+
+        # Optionally: add some noise or complexity to the reward function
+        # reward += np.random.normal(0, 0.01)  # Small noise can be added
+
+        # Terminate after each step (no episode length in this simple example)
+        done = True
+
+        # Empty info dictionary
+        info = {}
+
+        return self.state, reward, done, False, info
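For context, a minimal usage sketch of the updated environment follows. It assumes the `AutoRegressiveActionEnv` class from the diff above is in scope; the standalone instantiation, the randomly sampled action, and the `print` call are illustrative only and are not part of the commit.

```python
import numpy as np

# Minimal sketch: exercise the env interface once. Per the docstring in the
# diff, the episode ends after a single step.
env = AutoRegressiveActionEnv()

obs, info = env.reset(seed=0)
# obs is the internal state, a single value drawn uniformly from [-1, 1].

a1, a2 = env.action_space.sample()  # two independently sampled actions
obs, reward, done, truncated, info = env.step((a1, a2))

# reward = -|a2 - state * a1|, the negative absolute error between a2 and the
# state-scaled a1, so the second action head only scores well when it is
# conditioned on the first.
print(reward, done)  # done is True after the single step
```

Independently sampled actions like the ones above generally earn negative reward; per the docstring, learning a correlated (autoregressive) action distribution, e.g. the module referenced in examples/rl_modules/autoregressive_action_rlm.py, is one way to drive the reward toward zero.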