diff --git a/rllib/BUILD b/rllib/BUILD
index e681e4c252a0..3647830a52bf 100644
--- a/rllib/BUILD
+++ b/rllib/BUILD
@@ -1867,6 +1867,7 @@ py_test(
     srcs = ["offline/tests/test_offline_data.py"],
     data = [
         "tests/data/cartpole/cartpole-v1_large",
+        "tests/data/cartpole/large.json",
     ],
 )
diff --git a/rllib/algorithms/algorithm_config.py b/rllib/algorithms/algorithm_config.py
index 41cc4874e25c..f61f5cdffd76 100644
--- a/rllib/algorithms/algorithm_config.py
+++ b/rllib/algorithms/algorithm_config.py
@@ -437,6 +437,7 @@ def __init__(self, algo_class: Optional[type] = None):
         self.input_read_method_kwargs = {}
         self.input_read_schema = {}
         self.input_read_episodes = False
+        self.input_read_sample_batches = False
         self.input_compress_columns = [Columns.OBS, Columns.NEXT_OBS]
         self.input_spaces_jsonable = True
         self.map_batches_kwargs = {}
@@ -2403,6 +2404,7 @@ def offline_data(
         input_read_method_kwargs: Optional[Dict] = NotProvided,
         input_read_schema: Optional[Dict[str, str]] = NotProvided,
         input_read_episodes: Optional[bool] = NotProvided,
+        input_read_sample_batches: Optional[bool] = NotProvided,
         input_compress_columns: Optional[List[str]] = NotProvided,
         map_batches_kwargs: Optional[Dict] = NotProvided,
         iter_batches_kwargs: Optional[Dict] = NotProvided,
@@ -2464,8 +2466,19 @@ def offline_data(
                 inside of RLlib's schema. The other format is a columnar format and is
                 agnostic to the RL framework used. Use the latter format, if you are
                 unsure when to use the data or in which RL framework. The default is
-                to read column data, i.e. `False`. See also `output_write_episodes`
-                to define the output data format when recording.
+                to read column data, i.e. `False`. `input_read_episodes` and
+                `input_read_sample_batches` cannot be `True` at the same time. See
+                also `output_write_episodes` to define the output data format when
+                recording.
+            input_read_sample_batches: Whether offline data is stored in RLlib's old
+                stack `SampleBatch` type. This is usually the case for older data
+                recorded with RLlib in JSON line format. Reading in `SampleBatch`
+                data needs extra transforms and might not concatenate episode chunks
+                contained in different `SampleBatch`es in the data. If possible, avoid
+                reading `SampleBatch`es and instead convert them in a controlled form
+                into RLlib's `EpisodeType`s (i.e. `SingleAgentEpisode` or
+                `MultiAgentEpisode`). The default is `False`. `input_read_episodes`
+                and `input_read_sample_batches` cannot be `True` at the same time.
             input_compress_columns: What input columns are compressed with LZ4 in the
                 input data. If data is stored in `RLlib`'s `SingleAgentEpisode` (
                 `MultiAgentEpisode` not supported, yet). Note,
@@ -2565,6 +2578,8 @@ def offline_data(
             self.input_read_schema = input_read_schema
         if input_read_episodes is not NotProvided:
             self.input_read_episodes = input_read_episodes
+        if input_read_sample_batches is not NotProvided:
+            self.input_read_sample_batches = input_read_sample_batches
         if input_compress_columns is not NotProvided:
             self.input_compress_columns = input_compress_columns
         if map_batches_kwargs is not NotProvided:
diff --git a/rllib/algorithms/marwil/marwil_offline_prelearner.py b/rllib/algorithms/marwil/marwil_offline_prelearner.py
index 62ea8d5b1bfc..5ba39feab5ae 100644
--- a/rllib/algorithms/marwil/marwil_offline_prelearner.py
+++ b/rllib/algorithms/marwil/marwil_offline_prelearner.py
@@ -26,6 +26,15 @@ def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, MultiAgentBatch]:
         # If we directly read in episodes we just convert to list.
         if self.input_read_episodes:
             episodes = batch["item"].tolist()
+        # Else, if we have old stack `SampleBatch`es.
+        elif self.input_read_sample_batches:
+            episodes = OfflinePreLearner._map_sample_batch_to_episode(
+                self._is_multi_agent,
+                batch,
+                finalize=False,
+                schema=SCHEMA | self.config.input_read_schema,
+                input_compress_columns=self.config.input_compress_columns,
+            )["episodes"]
         # Otherwise we ap the batch to episodes.
         else:
             # Map the batch to episodes.
diff --git a/rllib/offline/offline_data.py b/rllib/offline/offline_data.py
index 6c8ae08ea9ca..9dd102616e75 100644
--- a/rllib/offline/offline_data.py
+++ b/rllib/offline/offline_data.py
@@ -4,6 +4,7 @@
 from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
 from ray.rllib.core import COMPONENT_RL_MODULE
+from ray.rllib.env import INPUT_ENV_SPACES
 from ray.rllib.offline.offline_prelearner import OfflinePreLearner
 from ray.rllib.utils.annotations import (
     ExperimentalAPI,
@@ -24,7 +25,7 @@ def __init__(self, config: AlgorithmConfig):
         self.path = (
             config.input_ if isinstance(config.input_, list) else Path(config.input_)
         )
-        # Use `read_json` as default data read method.
+        # Use `read_parquet` as default data read method.
         self.data_read_method = config.input_read_method
         # Override default arguments for the data read method.
         self.data_read_method_kwargs = (
@@ -72,12 +73,13 @@ def sample(
         # TODO (simon, sven): The iterator depends on the `num_samples`, i.e.abs
         # sampling later with a different batch size would need a
         # reinstantiation of the iterator.
+
         self.batch_iterator = self.data.map_batches(
             self.prelearner_class,
             fn_constructor_kwargs={
                 "config": self.config,
                 "learner": self.learner_handles[0],
-                "spaces": self.spaces["__env__"],
+                "spaces": self.spaces[INPUT_ENV_SPACES],
             },
             batch_size=num_samples,
             **self.map_batches_kwargs,
diff --git a/rllib/offline/offline_prelearner.py b/rllib/offline/offline_prelearner.py
index d5d6bfed9659..ea7813f50d04 100644
--- a/rllib/offline/offline_prelearner.py
+++ b/rllib/offline/offline_prelearner.py
@@ -90,6 +90,7 @@ def __init__(
         self.config = config
         self.input_read_episodes = self.config.input_read_episodes
+        self.input_read_sample_batches = self.config.input_read_sample_batches
         # We need this learner to run the learner connector pipeline.
         # If it is a `Learner` instance, the `Learner` is local.
         if isinstance(learner, Learner):
@@ -144,7 +145,16 @@ def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, List[EpisodeType]]
         # If we directly read in episodes we just convert to list.
         if self.input_read_episodes:
             episodes = batch["item"].tolist()
-        # Otherwise we ap the batch to episodes.
+        # Else, if we have old stack `SampleBatch`es.
+        elif self.input_read_sample_batches:
+            episodes = OfflinePreLearner._map_sample_batch_to_episode(
+                self._is_multi_agent,
+                batch,
+                finalize=False,
+                schema=SCHEMA | self.config.input_read_schema,
+                input_compress_columns=self.config.input_compress_columns,
+            )["episodes"]
+        # Otherwise we map the batch to episodes.
         else:
             episodes = self._map_to_episodes(
                 self._is_multi_agent,
@@ -227,7 +237,7 @@ def _should_module_be_updated(self, module_id, multi_agent_batch=None):
     @staticmethod
     def _map_to_episodes(
         is_multi_agent: bool,
-        batch: Dict[str, np.ndarray],
+        batch: Dict[str, Union[list, np.ndarray]],
         schema: Dict[str, str] = SCHEMA,
         finalize: bool = False,
         input_compress_columns: Optional[List[str]] = None,
@@ -271,7 +281,7 @@ def convert(sample, space):
         if is_multi_agent:
             # TODO (simon): Add support for multi-agent episodes.
-            pass
+            raise NotImplementedError
         else:
             # Build a single-agent episode with a single row of the batch.
             episode = SingleAgentEpisode(
@@ -288,7 +298,7 @@ def convert(sample, space):
                         unpack_if_needed(batch[schema[Columns.NEXT_OBS]][i]),
                         observation_space,
                     )
-                    if Columns.NEXT_OBS in input_compress_columns
+                    if Columns.OBS in input_compress_columns
                     else convert(
                         batch[schema[Columns.NEXT_OBS]][i], observation_space
                     ),
@@ -334,7 +344,11 @@ def convert(sample, space):
                         else v[i]
                     ]
                     for k, v in batch.items()
-                    if (k not in schema and k not in schema.values())
+                    if (
+                        k not in schema
+                        and k not in schema.values()
+                        and k not in ["dones", "agent_index", "type"]
+                    )
                 },
                 len_lookback_buffer=0,
             )
@@ -344,3 +358,112 @@ def convert(sample, space):
             episodes.append(episode)
         # Note, `map_batches` expects a `Dict` as return value.
         return {"episodes": episodes}
+
+    @staticmethod
+    def _map_sample_batch_to_episode(
+        is_multi_agent: bool,
+        batch: Dict[str, Union[list, np.ndarray]],
+        schema: Dict[str, str] = SCHEMA,
+        finalize: bool = False,
+        input_compress_columns: Optional[List[str]] = None,
+    ) -> Dict[str, List[EpisodeType]]:
+        """Maps an old stack `SampleBatch` to new stack episodes."""
+
+        # Set `input_compress_columns` to an empty `list` if `None`.
+        input_compress_columns = input_compress_columns or []
+
+        # TODO (simon): Check, if needed. It could possibly happen that a batch contains
+        # data from different episodes. Merging and resplitting the batch would then
+        # be the solution.
+        # Check, if batch comes actually from multiple episodes.
+        # episode_begin_indices = np.where(np.diff(np.hstack(batch["eps_id"])) != 0) + 1
+
+        # Define a container to collect episodes.
+        episodes = []
+        # Loop over `SampleBatch`es in the `ray.data` batch (a dict).
+        for i, obs in enumerate(batch[schema[Columns.OBS]]):
+
+            # If multi-agent we need to extract the agent ID.
+            # TODO (simon): Check, what happens with the module ID.
+            if is_multi_agent:
+                agent_id = (
+                    # The old stack uses "agent_index" instead of "agent_id".
+                    batch[schema["agent_index"]][i][0]
+                    if schema["agent_index"] in batch
+                    else None
+                )
+            else:
+                agent_id = None
+
+            if is_multi_agent:
+                # TODO (simon): Add support for multi-agent episodes.
+                raise NotImplementedError
+            else:
+                # Unpack observations, if needed.
+                obs = (
+                    unpack_if_needed(obs.tolist())
+                    if schema[Columns.OBS] in input_compress_columns
+                    else obs.tolist()
+                )
+                # Append the last `new_obs` to get the correct length of observations.
+                obs.append(
+                    unpack_if_needed(batch[schema[Columns.NEXT_OBS]][i][-1])
+                    if schema[Columns.OBS] in input_compress_columns
+                    else batch[schema[Columns.NEXT_OBS]][i][-1]
+                )
+                # Create a `SingleAgentEpisode`.
+                episode = SingleAgentEpisode(
+                    id_=batch[schema[Columns.EPS_ID]][i][0],
+                    agent_id=agent_id,
+                    observations=obs,
+                    infos=(
+                        batch[schema[Columns.INFOS]][i]
+                        if schema[Columns.INFOS] in batch
+                        else [{}] * len(obs)
+                    ),
+                    # Actions might be serialized. We deserialize them here.
+                    actions=(
+                        unpack_if_needed(batch[schema[Columns.ACTIONS]][i])
+                        if Columns.ACTIONS in input_compress_columns
+                        else batch[schema[Columns.ACTIONS]][i]
+                    ),
+                    rewards=batch[schema[Columns.REWARDS]][i],
+                    terminated=(
+                        any(batch[schema[Columns.TERMINATEDS]][i])
+                        if schema[Columns.TERMINATEDS] in batch
+                        else any(batch["dones"][i])
+                    ),
+                    truncated=(
+                        any(batch[schema[Columns.TRUNCATEDS]][i])
+                        if schema[Columns.TRUNCATEDS] in batch
+                        else False
+                    ),
+                    # TODO (simon): Results in zero-length episodes in connector.
+                    # t_started=batch[Columns.T if Columns.T in batch else
+                    # "unroll_id"][i][0],
+                    # TODO (simon): Single-dimensional columns are not supported.
+                    # Extra model outputs might be serialized. We deserialize them
+                    # here, if needed.
+                    # TODO (simon): Check, if we need here also reconversion from
+                    # JSONable in case of composite spaces.
+                    extra_model_outputs={
+                        k: unpack_if_needed(v[i])
+                        if k in input_compress_columns
+                        else v[i]
+                        for k, v in batch.items()
+                        if (
+                            k not in schema
+                            and k not in schema.values()
+                            and k not in ["dones", "agent_index", "type"]
+                        )
+                    },
+                    len_lookback_buffer=0,
+                )
+            # Finalize, if necessary.
+            # TODO (simon, sven): Check, if we should convert all data to lists
+            # before. Right now only obs are lists.
+            if finalize:
+                episode.finalize()
+            episodes.append(episode)
+        # Note, `map_batches` expects a `Dict` as return value.
+        return {"episodes": episodes}
diff --git a/rllib/offline/tests/test_offline_data.py b/rllib/offline/tests/test_offline_data.py
index 123dbab55b15..5e67dbb2ef50 100644
--- a/rllib/offline/tests/test_offline_data.py
+++ b/rllib/offline/tests/test_offline_data.py
@@ -46,6 +46,26 @@ def test_offline_convert_to_episodes(self):
         self.assertTrue(len(episodes) == 10)
         self.assertTrue(isinstance(episodes[0], SingleAgentEpisode))
 
+    def test_offline_convert_from_old_sample_batch_to_episodes(self):
+
+        base_path = Path(__file__).parents[2]
+        sample_batch_data_path = base_path / "tests/data/cartpole/large.json"
+        config = AlgorithmConfig().offline_data(
+            input_=["local://" + sample_batch_data_path.as_posix()],
+            input_read_method="read_json",
+            input_read_sample_batches=True,
+        )
+
+        offline_data = OfflineData(config)
+
+        batch = offline_data.data.take_batch(batch_size=10)
+        episodes = OfflinePreLearner._map_sample_batch_to_episode(False, batch)[
+            "episodes"
+        ]
+
+        self.assertTrue(len(episodes) == 10)
+        self.assertTrue(isinstance(episodes[0], SingleAgentEpisode))
+
     def test_sample(self):
         config = AlgorithmConfig().offline_data(input_=[self.data_path])
diff --git a/rllib/tuned_examples/bc/cartpole_bc.py b/rllib/tuned_examples/bc/cartpole_bc.py
index b8f6101ae026..f52b655a2d18 100644
--- a/rllib/tuned_examples/bc/cartpole_bc.py
+++ b/rllib/tuned_examples/bc/cartpole_bc.py
@@ -1,11 +1,11 @@
 from pathlib import Path
 
+from ray.air.constants import TRAINING_ITERATION
 from ray.rllib.algorithms.bc import BCConfig
 from ray.rllib.utils.metrics import (
     ENV_RUNNER_RESULTS,
     EPISODE_RETURN_MEAN,
     EVALUATION_RESULTS,
-    TRAINING_ITERATION_TIMER,
 )
 from ray.rllib.utils.test_utils import (
     add_rllib_example_script_args,
@@ -75,7 +75,7 @@
 
 stop = {
     f"{EVALUATION_RESULTS}/{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}": 120.0,
-    TRAINING_ITERATION_TIMER: 350.0,
+    TRAINING_ITERATION: 350,
 }
 
 if __name__ == "__main__":
diff --git a/rllib/tuned_examples/bc/pendulum_bc.py b/rllib/tuned_examples/bc/pendulum_bc.py
index ffd26e471b53..6988f0989ac6 100644
--- a/rllib/tuned_examples/bc/pendulum_bc.py
+++ b/rllib/tuned_examples/bc/pendulum_bc.py
@@ -1,11 +1,11 @@
 from pathlib import Path
 
+from ray.air.constants import TRAINING_ITERATION
 from ray.rllib.algorithms.bc import BCConfig
 from ray.rllib.utils.metrics import (
     ENV_RUNNER_RESULTS,
     EPISODE_RETURN_MEAN,
     EVALUATION_RESULTS,
-    TRAINING_ITERATION_TIMER,
 )
 from ray.rllib.utils.test_utils import (
     add_rllib_example_script_args,
@@ -61,7 +61,7 @@
 
 stop = {
     f"{EVALUATION_RESULTS}/{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}": -200.0,
-    TRAINING_ITERATION_TIMER: 350.0,
+    TRAINING_ITERATION: 350,
 }
 
 if __name__ == "__main__":
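
Both new call sites build the effective column mapping as `SCHEMA | self.config.input_read_schema`, the Python 3.9+ dict union in which keys of the right-hand operand win, so a user-supplied `input_read_schema` overrides the defaults key by key. A toy illustration of that merge (the key names below are illustrative, not the actual contents of `SCHEMA`):

    default_schema = {"obs": "obs", "actions": "actions"}
    user_schema = {"obs": "observations"}  # Custom column name for observations.
    merged = default_schema | user_schema
    assert merged == {"obs": "observations", "actions": "actions"}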
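
The `obs.append(...)` step in the new `_map_sample_batch_to_episode` accounts for the fact that an old-stack `SampleBatch` stores T transitions as parallel `obs`/`new_obs` columns, while a `SingleAgentEpisode` expects T + 1 observations. A toy numeric sketch of that reasoning (values made up):

    obs = [0, 1, 2]      # "obs" column: one entry per timestep (T = 3).
    new_obs = [1, 2, 3]  # "new_obs" column: the same data shifted by one step.
    episode_observations = obs + [new_obs[-1]]
    assert episode_observations == [0, 1, 2, 3]  # T + 1 observations.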
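
Putting the pieces together, a minimal usage sketch of the new `input_read_sample_batches` flag, mirroring the new unit test above; the data path and environment here are hypothetical, and note that old-stack JSON recordings require overriding the now-default `read_parquet` read method:

    from ray.rllib.algorithms.bc import BCConfig

    config = (
        BCConfig()
        .environment(env="CartPole-v1")
        .offline_data(
            # Hypothetical path to old-stack `SampleBatch` data in JSON-lines format.
            input_=["local:///tmp/cartpole/large.json"],
            # Old-stack recordings are JSON, so override the `read_parquet` default.
            input_read_method="read_json",
            # Rows are `SampleBatch`es, not episodes or columns; must not be
            # combined with `input_read_episodes=True`.
            input_read_sample_batches=True,
        )
    )
    algo = config.build()
    print(algo.train())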