Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RLlib] [Connectors] Fix test nested action spaces connectors #30459

Merged
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 16 additions & 3 deletions rllib/evaluation/collectors/agent_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@
from ray.rllib.policy.view_requirement import ViewRequirement
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.spaces.space_utils import get_dummy_batch_for_space
from ray.rllib.utils.spaces.space_utils import (
flatten_to_single_ndarray,
get_dummy_batch_for_space,
)
from ray.rllib.utils.typing import (
EpisodeID,
EnvID,
Expand Down Expand Up @@ -261,11 +264,16 @@ def add_action_reward_next_obs(self, input_values: Dict[str, TensorType]) -> Non
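# Store infos, state_out_ and (unless action flattening is disabled)
# actions as single, monolithic items. Infos/state-outs may be structs
# that change from timestep to timestep; actions are first collapsed
# into one ndarray via `flatten_to_single_ndarray`.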
should_flatten_action_key = (
k == SampleBatch.ACTIONS and not self.disable_action_flattening
)
if (
k == SampleBatch.INFOS
or k.startswith("state_out_")
or (k == SampleBatch.ACTIONS and not self.disable_action_flattening)
or should_flatten_action_key
):
if should_flatten_action_key:
v = flatten_to_single_ndarray(v)
self.buffers[k][0].append(v)
# Flatten all other columns.
else:
Expand Down Expand Up @@ -506,11 +514,16 @@ def _build_buffers(self, single_row: Dict[str, TensorType]) -> None:
# lists. These are monolithic items (infos is a dict that
# should not be further split, same for state-out items, which
# could be custom dicts as well).
should_flatten_action_key = (
col == SampleBatch.ACTIONS and not self.disable_action_flattening
)
if (
col == SampleBatch.INFOS
or col.startswith("state_out_")
or (col == SampleBatch.ACTIONS and not self.disable_action_flattening)
or should_flatten_action_key
):
if should_flatten_action_key:
data = flatten_to_single_ndarray(data)
self.buffers[col] = [[data for _ in range(shift)]]
else:
self.buffers[col] = [
Expand Down
3 changes: 2 additions & 1 deletion rllib/tests/test_nested_action_spaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from ray.rllib.algorithms.pg import PG, DEFAULT_CONFIG
from ray.rllib.examples.env.random_env import RandomEnv
from ray.rllib.offline.json_reader import JsonReader
from ray.rllib.policy.sample_batch import convert_ma_batch_to_sample_batch
from ray.rllib.utils.test_utils import framework_iterator

SPACES = {
Expand Down Expand Up @@ -76,7 +77,6 @@ def test_nested_action_spaces(self):

# Remove the lr schedule from the config: it is not needed here and is not supported by BC.
del config["lr_schedule"]

for _ in framework_iterator(config):
for name, action_space in SPACES.items():
config["env_config"] = {
Expand All @@ -97,6 +97,7 @@ def test_nested_action_spaces(self):
ioctx=pg.workers.local_worker().io_context,
)
sample_batch = reader.next()
sample_batch = convert_ma_batch_to_sample_batch(sample_batch)
if flatten:
assert isinstance(sample_batch["actions"], np.ndarray)
assert len(sample_batch["actions"].shape) == 2
Expand Down