diff --git a/rllib/evaluation/collectors/agent_collector.py b/rllib/evaluation/collectors/agent_collector.py
index 3a55a581fecc..4a6a7a320f02 100644
--- a/rllib/evaluation/collectors/agent_collector.py
+++ b/rllib/evaluation/collectors/agent_collector.py
@@ -10,7 +10,10 @@
 from ray.rllib.policy.view_requirement import ViewRequirement
 from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.utils.framework import try_import_tf, try_import_torch
-from ray.rllib.utils.spaces.space_utils import get_dummy_batch_for_space
+from ray.rllib.utils.spaces.space_utils import (
+    flatten_to_single_ndarray,
+    get_dummy_batch_for_space,
+)
 from ray.rllib.utils.typing import (
     EpisodeID,
     EnvID,
@@ -261,11 +264,16 @@ def add_action_reward_next_obs(self, input_values: Dict[str, TensorType]) -> Non
             # Do not flatten infos, state_out_ and (if configured) actions.
             # Infos/state-outs may be structs that change from timestep to
             # timestep.
+            should_flatten_action_key = (
+                k == SampleBatch.ACTIONS and not self.disable_action_flattening
+            )
             if (
                 k == SampleBatch.INFOS
                 or k.startswith("state_out_")
-                or (k == SampleBatch.ACTIONS and not self.disable_action_flattening)
+                or should_flatten_action_key
             ):
+                if should_flatten_action_key:
+                    v = flatten_to_single_ndarray(v)
                 self.buffers[k][0].append(v)
             # Flatten all other columns.
             else:
@@ -506,11 +514,16 @@ def _build_buffers(self, single_row: Dict[str, TensorType]) -> None:
             # lists. These are monolithic items (infos is a dict that
             # should not be further split, same for state-out items, which
             # could be custom dicts as well).
+            should_flatten_action_key = (
+                col == SampleBatch.ACTIONS and not self.disable_action_flattening
+            )
             if (
                 col == SampleBatch.INFOS
                 or col.startswith("state_out_")
-                or (col == SampleBatch.ACTIONS and not self.disable_action_flattening)
+                or should_flatten_action_key
             ):
+                if should_flatten_action_key:
+                    data = flatten_to_single_ndarray(data)
                 self.buffers[col] = [[data for _ in range(shift)]]
             else:
                 self.buffers[col] = [
diff --git a/rllib/tests/test_nested_action_spaces.py b/rllib/tests/test_nested_action_spaces.py
index 465a8c8cc270..35371307e465 100644
--- a/rllib/tests/test_nested_action_spaces.py
+++ b/rllib/tests/test_nested_action_spaces.py
@@ -10,6 +10,7 @@
 from ray.rllib.algorithms.pg import PG, DEFAULT_CONFIG
 from ray.rllib.examples.env.random_env import RandomEnv
 from ray.rllib.offline.json_reader import JsonReader
+from ray.rllib.policy.sample_batch import convert_ma_batch_to_sample_batch
 from ray.rllib.utils.test_utils import framework_iterator
 
 SPACES = {
@@ -76,7 +77,6 @@ def test_nested_action_spaces(self):
 
         # Remove lr schedule from config, not needed here, and not supported by BC.
         del config["lr_schedule"]
-
         for _ in framework_iterator(config):
             for name, action_space in SPACES.items():
                 config["env_config"] = {
@@ -97,6 +97,7 @@
                     ioctx=pg.workers.local_worker().io_context,
                 )
                 sample_batch = reader.next()
+                sample_batch = convert_ma_batch_to_sample_batch(sample_batch)
                 if flatten:
                     assert isinstance(sample_batch["actions"], np.ndarray)
                     assert len(sample_batch["actions"].shape) == 2
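
Note on the agent_collector change: when action flattening is enabled (i.e. disable_action_flattening is False), nested Tuple/Dict actions are now collapsed into a single 1-D ndarray via flatten_to_single_ndarray before being appended to the ACTIONS buffer, so every buffered row has a uniform shape. Below is a minimal, illustrative sketch of that behavior (not part of the patch); it assumes the gym-based spaces used by this RLlib version, and the sampled values in the comments are only examples.

import numpy as np
from gym.spaces import Box, Discrete, Tuple
from ray.rllib.utils.spaces.space_utils import flatten_to_single_ndarray

# A nested action space, similar to those exercised in test_nested_action_spaces.
space = Tuple([Discrete(4), Box(-1.0, 1.0, shape=(2,))])
action = space.sample()  # e.g. (2, np.array([0.3, -0.7], dtype=np.float32))

# The collector now buffers this as one flat ndarray instead of a nested struct.
flat = flatten_to_single_ndarray(action)
print(flat.shape)  # (3,): one Discrete element plus two Box elements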
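Note on the test change: depending on how the offline data was recorded, reader.next() may return a MultiAgentBatch rather than a SampleBatch, and the new convert_ma_batch_to_sample_batch call normalizes the result so that column access like sample_batch["actions"] works in both cases. A small usage sketch under that assumption (the read_actions helper is hypothetical, not part of the patch):

from ray.rllib.policy.sample_batch import (
    SampleBatch,
    convert_ma_batch_to_sample_batch,
)

def read_actions(reader):
    # reader: any RLlib input reader, e.g. the JsonReader built in the test.
    batch = reader.next()
    # A single-policy MultiAgentBatch becomes a plain SampleBatch; a
    # SampleBatch passes through unchanged.
    batch = convert_ma_batch_to_sample_batch(batch)
    return batch[SampleBatch.ACTIONS]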