Make action unwritable #24218

Closed
14 changes: 14 additions & 0 deletions rllib/evaluation/sampler.py
@@ -718,6 +718,7 @@ def new_episode(env_id):
        # Return computed actions to ready envs. We also send to envs that have
        # taken off-policy actions; those envs are free to ignore the action.
        t4 = time.time()

        base_env.send_actions(actions_to_send)
        perf_stats.env_wait_time += time.time() - t4

@@ -1244,12 +1245,25 @@ def _process_policy_eval_results(
            episode._set_last_extra_action_outs(
                agent_id, {k: v[i] for k, v in extra_action_out_cols.items()}
            )

            if env_id in off_policy_actions and agent_id in off_policy_actions[env_id]:
                episode._set_last_action(agent_id, off_policy_actions[env_id][agent_id])
            else:
                episode._set_last_action(agent_id, action)

            assert agent_id not in actions_to_send[env_id]

            # Flag actions as immutable to notify users when they try to change
            # an action and to avoid hard-to-trace errors.
            def make_action_immutable(obj):
                # Numpy arrays are flagged read-only in place.
                if isinstance(obj, np.ndarray):
                    obj.setflags(write=False)
                    return obj
                # Plain dicts are wrapped in a read-only view.
                elif isinstance(obj, dict):
                    from types import MappingProxyType

                    return MappingProxyType(obj)
                else:
                    return obj

            tree.map_structure(make_action_immutable, action_to_send)
            actions_to_send[env_id][agent_id] = action_to_send

    return actions_to_send
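
For context (not part of the patch): a minimal standalone sketch of what mapping the helper over a nested action does, assuming the dm-tree package (imported as `tree` in RLlib). `map_structure` passes only leaf values to the function, so numpy leaves are flagged read-only in place; the nested action shown here is purely illustrative.

import numpy as np
import tree  # dm-tree, the nested-structure utility RLlib already uses

def make_action_immutable(obj):
    # Same idea as the helper in the patch: numpy leaves become read-only.
    if isinstance(obj, np.ndarray):
        obj.setflags(write=False)
    return obj

# Hypothetical nested action, e.g. from a Dict action space.
action_to_send = {"move": np.zeros(2, np.float32), "jump": np.zeros(1, np.float32)}
tree.map_structure(make_action_immutable, action_to_send)
print(action_to_send["move"].flags.writeable)  # -> False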
41 changes: 41 additions & 0 deletions rllib/evaluation/tests/test_rollout_worker.py
@@ -359,6 +359,47 @@ def test_action_normalization(self):
        self.assertGreater(np.max(sample["actions"]), action_space.high[0])
        self.assertLess(np.min(sample["actions"]), action_space.low[0])
        ev.stop()

    def test_action_immutability(self):
        from ray.rllib.examples.env.random_env import RandomEnv

        action_space = gym.spaces.Box(0.0001, 0.0002, (5,))

        class ActionMutationEnv(RandomEnv):
            def __init__(self, config):
                self.test_case = config["test_case"]
                super().__init__(config=config)

            def step(self, action):
                # Check whether this step() call originates from the env
                # pre-checker (`check_gym_environments`); actions coming from
                # there are not produced by the sampler and may be mutable.
                import inspect

                curframe = inspect.currentframe()
                called_from_check = any(
                    frame[3] == "check_gym_environments"
                    for frame in inspect.getouterframes(curframe, 2)
                )
                # Check whether the action is immutable.
                if action.flags.writeable and not called_from_check:
                    self.test_case.assertFalse(
                        action.flags.writeable, "Action is mutable"
                    )
                return super().step(action)

        ev = RolloutWorker(
            env_creator=lambda _: ActionMutationEnv(
                config=dict(
                    test_case=self,
                    action_space=action_space,
                    max_episode_len=10,
                    p_done=0.0,
                    check_action_bounds=True,
                )
            ),
            policy_spec=RandomPolicy,
            policy_config=dict(
                action_space=action_space,
                ignore_action_bounds=True,
            ),
            clip_actions=False,
            batch_mode="complete_episodes",
        )
        ev.sample()
        ev.stop()

    def test_reward_clipping(self):
        # Clipping: True (clip between -1.0 and 1.0).
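
For reference (an illustrative note, not part of the patch): once an action array has been flagged read-only, any in-place write in a user's step() fails loudly, which is the behavior the test above guards. A tiny sketch of the error an env author would see:

import numpy as np

action = np.zeros(5, dtype=np.float32)
action.setflags(write=False)
try:
    action[0] = 1.0  # in-place mutation of an immutable action
except ValueError as e:
    print(e)  # "assignment destination is read-only"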