From 15b3ea7fe2743cfc120f9ec1ba330a9eab08b419 Mon Sep 17 00:00:00 2001 From: simonsays1980 Date: Sat, 24 Aug 2024 11:53:44 +0200 Subject: [PATCH 1/6] Started to write converter from old 'SampleBatch' to new 'Episodes' in offline data recorded with the old stack. Signed-off-by: simonsays1980 --- rllib/algorithms/algorithm_config.py | 19 +++- rllib/offline/offline_data.py | 53 +++++++++--- rllib/offline/offline_prelearner.py | 115 ++++++++++++++++++++++++- rllib/tuned_examples/bc/cartpole_bc.py | 26 +++++- 4 files changed, 192 insertions(+), 21 deletions(-) diff --git a/rllib/algorithms/algorithm_config.py b/rllib/algorithms/algorithm_config.py index 0648f236edda..f3e976105ac8 100644 --- a/rllib/algorithms/algorithm_config.py +++ b/rllib/algorithms/algorithm_config.py @@ -435,6 +435,7 @@ def __init__(self, algo_class: Optional[type] = None): self.input_read_method_kwargs = {} self.input_read_schema = {} self.input_read_episodes = False + self.input_read_sample_batches = False self.input_compress_columns = [Columns.OBS, Columns.NEXT_OBS] self.input_spaces_jsonable = True self.map_batches_kwargs = {} @@ -2392,6 +2393,7 @@ def offline_data( input_read_method_kwargs: Optional[Dict] = NotProvided, input_read_schema: Optional[Dict[str, str]] = NotProvided, input_read_episodes: Optional[bool] = NotProvided, + input_read_sample_batches: Optional[bool] = NotProvided, input_compress_columns: Optional[List[str]] = NotProvided, map_batches_kwargs: Optional[Dict] = NotProvided, iter_batches_kwargs: Optional[Dict] = NotProvided, @@ -2453,8 +2455,19 @@ def offline_data( inside of RLlib's schema. The other format is a columnar format and is agnostic to the RL framework used. Use the latter format, if you are unsure when to use the data or in which RL framework. The default is - to read column data, i.e. `False`. See also `output_write_episodes` - to define the output data format when recording. + to read column data, i.e. `False`. input_read_episodes` and + `inpuit_read_sample_batches` cannot be `True` at the same time.See + also `output_write_episodes`to define the output data format when + recording. + input_read_sample_batches: Whether offline data is stored in RLlib's old + stack `SampleBatch` type. This is usually the case for older data + recorded with RLlib in JSON line format. Reading in `SampleBatch` + data needs extra transforms and might not concatenate episode chunks + contained in different `SampleBatch`es in the data. If possible avoid + to read `SampleBatch`es and convert them in a controlled form into + RLlib`s `EpisodeType`s (i.e. `SingleAgentEpisode` or + `MultiAgentEpisode`). The default is `False`. `input_read_episodes` + and `inpuit_read_sample_batches` cannot be `True` at the same time. input_compress_columns: What input columns are compressed with LZ4 in the input data. If data is stored in `RLlib`'s `SingleAgentEpisode` ( `MultiAgentEpisode` not supported, yet). Note, @@ -2554,6 +2567,8 @@ def offline_data( self.input_read_schema = input_read_schema if input_read_episodes is not NotProvided: self.input_read_episodes = input_read_episodes + if input_read_sample_batches is not NotProvided: + self.input_read_sample_batches = input_read_sample_batches if input_compress_columns is not NotProvided: self.input_compress_columns = input_compress_columns if map_batches_kwargs is not NotProvided: diff --git a/rllib/offline/offline_data.py b/rllib/offline/offline_data.py index 6c8ae08ea9ca..5a781c6b63bb 100644 --- a/rllib/offline/offline_data.py +++ b/rllib/offline/offline_data.py @@ -4,6 +4,7 @@ from ray.rllib.algorithms.algorithm_config import AlgorithmConfig from ray.rllib.core import COMPONENT_RL_MODULE +from ray.rllib.env import INPUT_ENV_SPACES from ray.rllib.offline.offline_prelearner import OfflinePreLearner from ray.rllib.utils.annotations import ( ExperimentalAPI, @@ -41,9 +42,19 @@ def __init__(self, config: AlgorithmConfig): logger.error(e) # Avoids reinstantiating the batch iterator each time we sample. self.batch_iterator = None + self.map_method = ( + "map" + if self.config.input_read_episodes or self.config.input_read_sample_batches + else "map_batches" + ) self.map_batches_kwargs = ( self.default_map_batches_kwargs | self.config.map_batches_kwargs ) + self.iter_method = ( + "iter_rows" + if self.config.input_read_episodes or self.config.input_read_sample_batches + else "iter_batches" + ) self.iter_batches_kwargs = ( self.default_iter_batches_kwargs | self.config.iter_batches_kwargs ) @@ -72,19 +83,35 @@ def sample( # TODO (simon, sven): The iterator depends on the `num_samples`, i.e.abs # sampling later with a different batch size would need a # reinstantiation of the iterator. - self.batch_iterator = self.data.map_batches( - self.prelearner_class, - fn_constructor_kwargs={ - "config": self.config, - "learner": self.learner_handles[0], - "spaces": self.spaces["__env__"], - }, - batch_size=num_samples, - **self.map_batches_kwargs, - ).iter_batches( - batch_size=num_samples, - **self.iter_batches_kwargs, - ) + if self.config.input_read_sample_batches: + self.batch_iterator = self.data.map( + self.prelearner_class, + fn_constructor_kwargs={ + "config": self.config, + "learner": self.learner_handles[0], + "spaces": self.spaces[INPUT_ENV_SPACES], + }, + concurrency=1, + # batch_size=num_samples, + # **self.map_batches_kwargs, + ).iter_batches( + batch_size=num_samples, + # **self.iter_batches_kwargs, + ) + else: + self.batch_iterator = self.data.map_batches( + self.prelearner_class, + fn_constructor_kwargs={ + "config": self.config, + "learner": self.learner_handles[0], + "spaces": self.spaces[INPUT_ENV_SPACES], + }, + batch_size=num_samples, + **self.map_batches_kwargs, + ).iter_batches( + batch_size=num_samples, + **self.iter_batches_kwargs, + ) # Do we want to return an iterator or a single batch? if return_iterator: diff --git a/rllib/offline/offline_prelearner.py b/rllib/offline/offline_prelearner.py index 743e80d644c5..037d5003e10f 100644 --- a/rllib/offline/offline_prelearner.py +++ b/rllib/offline/offline_prelearner.py @@ -90,6 +90,7 @@ def __init__( self.config = config self.input_read_episodes = self.config.input_read_episodes + self.input_read_sample_batches = self.config.input_read_sample_batches # We need this learner to run the learner connector pipeline. # If it is a `Learner` instance, the `Learner` is local. if isinstance(learner, Learner): @@ -144,7 +145,16 @@ def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, List[EpisodeType]] # If we directly read in episodes we just convert to list. if self.input_read_episodes: episodes = batch["item"].tolist() - # Otherwise we ap the batch to episodes. + # Else, if we have old stack `SampleBatch`es. + elif self.input_read_sample_batches: + episodes = OfflinePreLearner._map_batch_to_episode( + self._is_multi_agent, + batch, + finalize=False, + schema=SCHEMA | self.config.input_read_schema, + input_compress_columns=self.config.input_compress_columns, + )["episodes"] + # Otherwise we map the batch to episodes. else: episodes = self._map_to_episodes( self._is_multi_agent, @@ -227,7 +237,7 @@ def _should_module_be_updated(self, module_id, multi_agent_batch=None): @staticmethod def _map_to_episodes( is_multi_agent: bool, - batch: Dict[str, np.ndarray], + batch: Dict[str, Union[list, np.ndarray]], schema: Dict[str, str] = SCHEMA, finalize: bool = False, input_compress_columns: Optional[List[str]] = None, @@ -288,7 +298,7 @@ def convert(sample, space): unpack_if_needed(batch[schema[Columns.NEXT_OBS]][i]), observation_space, ) - if Columns.NEXT_OBS in input_compress_columns + if Columns.OBS in input_compress_columns else convert( batch[schema[Columns.NEXT_OBS]][i], observation_space ), @@ -334,7 +344,11 @@ def convert(sample, space): else v[i] ] for k, v in batch.items() - if (k not in schema and k not in schema.values()) + if ( + k not in schema + and k not in schema.values() + and k not in ["dones", "agent_index", "type"] + ) }, len_lookback_buffer=0, ) @@ -344,3 +358,96 @@ def convert(sample, space): episodes.append(episode) # Note, `map_batches` expects a `Dict` as return value. return {"episodes": episodes} + + def _map_batch_to_episode( + is_multi_agent: bool, + batch: Dict[str, Union[list, np.ndarray]], + schema: Dict[str, str] = SCHEMA, + finalize: bool = False, + input_compress_columns: Optional[List[str]] = None, + ) -> Dict[str, List[EpisodeType]]: + + episodes = [] + # Note, the `batch` will contain multiple rows and each row contains + # a single `SampleBatch` with experiences from one or more episodes. + # Loop over rows. + # for b in batch: + b = batch + # If multi-agent we need to extract the agent ID. + # TODO (simon): Check, what happens with the module ID. + if is_multi_agent: + agent_id = ( + # The old stack uses "agent_index" instead of "agent_id". + b[schema["agent_index"]][0] + if schema["agent_index"] in b + else None + ) + else: + agent_id = None + + if is_multi_agent: + # TODO (simon): Add support for multi-agent episodes. + pass + else: + obs = ( + unpack_if_needed(b[schema[Columns.OBS]]) + if schema[Columns.OBS] in input_compress_columns + else b[schema[Columns.OBS]] + ) + obs.append( + unpack_if_needed(b[schema[Columns.NEXT_OBS]][-1]) + if schema[Columns.OBS] in input_compress_columns + else b[schema[Columns.NEXT_OBS]][-1] + ) + episode = SingleAgentEpisode( + id_=b[schema[Columns.EPS_ID]][0], + agent_id=agent_id, + observations=obs, + infos=( + b[schema[Columns.INFOS]] + if schema[Columns.INFOS] in b + else [{}] * len(obs) + ), + # Actions might be (a) serialized and/or (b) converted to a JSONable + # (when a composite space was used). We unserializer and then + # reconvert from JSONable to space sample. + actions=( + unpack_if_needed(b[schema[Columns.ACTIONS]]) + if Columns.ACTIONS in input_compress_columns + else b[schema[Columns.ACTIONS]] + ), + rewards=b[schema[Columns.REWARDS]], + terminated=( + any(b[schema[Columns.TERMINATEDS]]) + if schema[Columns.TERMINATEDS] in b + else any(b["dones"]) + ), + truncated=( + any(b[schema[Columns.TRUNCATEDS]]) + if schema[Columns.TRUNCATEDS] in b + else False + ), + # TODO (simon): Results in zero-length episodes in connector. + # t_started=batch[Columns.T if Columns.T in batch else + # "unroll_id"][i][0], + # TODO (simon): Single-dimensional columns are not supported. + # Extra model outputs might be serialized. We unserialize them here + # if needed. + # TODO (simon): Check, if we need here also reconversion from + # JSONable in case of composite spaces. + extra_model_outputs={ + k: unpack_if_needed(v) if k in input_compress_columns else v + for k, v in b.items() + if ( + k not in schema + and k not in schema.values() + and k not in ["dones", "agent_index", "type"] + ) + }, + len_lookback_buffer=0, + ) + if finalize: + episode.finalize() + episodes.append(episode) + # Note, `map_batches` expects a `Dict` as return value. + return {"episodes": episodes} diff --git a/rllib/tuned_examples/bc/cartpole_bc.py b/rllib/tuned_examples/bc/cartpole_bc.py index b8f6101ae026..7bae309b2df4 100644 --- a/rllib/tuned_examples/bc/cartpole_bc.py +++ b/rllib/tuned_examples/bc/cartpole_bc.py @@ -23,11 +23,12 @@ # Define the data paths. data_path = "tests/data/cartpole/cartpole-v1_large" +data_path = "tests/data/cartpole/large.json" base_path = Path(__file__).parents[2] print(f"base_path={base_path}") data_path = "local://" / base_path / data_path print(f"data_path={data_path}") - +args.no_tune = True # Define the BC config. config = ( BCConfig() @@ -49,6 +50,8 @@ # as remote learners. .offline_data( input_=[data_path.as_posix()], + input_read_method="read_json", + input_read_sample_batches=True, # Define the number of reading blocks, these should be larger than 1 # and aligned with the data size. input_read_method_kwargs={"override_num_blocks": max(args.num_gpus, 2)}, @@ -73,9 +76,28 @@ ) ) +# algo = config.build() +# from ray.rllib.offline.offline_prelearner import SCHEMA + +# oplr = algo.offline_data.prelearner_class( +# config=algo.offline_data.config, +# learner=algo.offline_data.learner_handles[0], +# spaces=algo.offline_data.spaces["__env__"], +# ) + + +# rows = algo.offline_data.data.take(2) +# algo.offline_data.prelearner_class._map_batch_to_episode( +# False, +# rows, +# finalize=False, +# schema=SCHEMA | algo.offline_data.config.input_read_schema, +# input_compress_columns=algo.offline_data.config.input_compress_columns, +# ) + stop = { f"{EVALUATION_RESULTS}/{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}": 120.0, - TRAINING_ITERATION_TIMER: 350.0, + TRAINING_ITERATION_TIMER: 350, } if __name__ == "__main__": From 8e88c46e1e72b10d1a6b0476bf3ac16de3751cad Mon Sep 17 00:00:00 2001 From: simonsays1980 Date: Tue, 27 Aug 2024 18:39:16 +0200 Subject: [PATCH 2/6] Added option to 'OfflinePreLearner' to convert old stack 'SampleBatch' to new stack 'EpisodeType' (only 'SingleAgentEpisode' for now). This enables users to use their old recorded agent data for Offline RL. Signed-off-by: simonsays1980 --- .../marwil/marwil_offline_prelearner.py | 9 + rllib/offline/offline_data.py | 43 ++--- rllib/offline/offline_prelearner.py | 175 ++++++++++-------- rllib/offline/tests/test_offline_data.py | 18 ++ rllib/tuned_examples/bc/cartpole_bc.py | 24 +-- 5 files changed, 137 insertions(+), 132 deletions(-) diff --git a/rllib/algorithms/marwil/marwil_offline_prelearner.py b/rllib/algorithms/marwil/marwil_offline_prelearner.py index 5fb7a45119bc..b7a9229095da 100644 --- a/rllib/algorithms/marwil/marwil_offline_prelearner.py +++ b/rllib/algorithms/marwil/marwil_offline_prelearner.py @@ -26,6 +26,15 @@ def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, MultiAgentBatch]: # If we directly read in episodes we just convert to list. if self.input_read_episodes: episodes = batch["item"].tolist() + # Else, if we have old stack `SampleBatch`es. + elif self.input_read_sample_batches: + episodes = OfflinePreLearner._map_batch_to_episode( + self._is_multi_agent, + batch, + finalize=False, + schema=SCHEMA | self.config.input_read_schema, + input_compress_columns=self.config.input_compress_columns, + )["episodes"] # Otherwise we ap the batch to episodes. else: # Map the batch to episodes. diff --git a/rllib/offline/offline_data.py b/rllib/offline/offline_data.py index 5a781c6b63bb..36c2ea3c592c 100644 --- a/rllib/offline/offline_data.py +++ b/rllib/offline/offline_data.py @@ -83,35 +83,20 @@ def sample( # TODO (simon, sven): The iterator depends on the `num_samples`, i.e.abs # sampling later with a different batch size would need a # reinstantiation of the iterator. - if self.config.input_read_sample_batches: - self.batch_iterator = self.data.map( - self.prelearner_class, - fn_constructor_kwargs={ - "config": self.config, - "learner": self.learner_handles[0], - "spaces": self.spaces[INPUT_ENV_SPACES], - }, - concurrency=1, - # batch_size=num_samples, - # **self.map_batches_kwargs, - ).iter_batches( - batch_size=num_samples, - # **self.iter_batches_kwargs, - ) - else: - self.batch_iterator = self.data.map_batches( - self.prelearner_class, - fn_constructor_kwargs={ - "config": self.config, - "learner": self.learner_handles[0], - "spaces": self.spaces[INPUT_ENV_SPACES], - }, - batch_size=num_samples, - **self.map_batches_kwargs, - ).iter_batches( - batch_size=num_samples, - **self.iter_batches_kwargs, - ) + + self.batch_iterator = self.data.map_batches( + self.prelearner_class, + fn_constructor_kwargs={ + "config": self.config, + "learner": self.learner_handles[0], + "spaces": self.spaces[INPUT_ENV_SPACES], + }, + batch_size=num_samples, + **self.map_batches_kwargs, + ).iter_batches( + batch_size=num_samples, + **self.iter_batches_kwargs, + ) # Do we want to return an iterator or a single batch? if return_iterator: diff --git a/rllib/offline/offline_prelearner.py b/rllib/offline/offline_prelearner.py index 037d5003e10f..93a6c5887068 100644 --- a/rllib/offline/offline_prelearner.py +++ b/rllib/offline/offline_prelearner.py @@ -366,88 +366,103 @@ def _map_batch_to_episode( finalize: bool = False, input_compress_columns: Optional[List[str]] = None, ) -> Dict[str, List[EpisodeType]]: + """Maps an old stack `SampleBatch` to new stack episodes.""" + # Set `input_compress_columns` to an empty `list` if `None`. + input_compress_columns = input_compress_columns or [] + + # TODO (simon): CHeck, if needed. It could possibly happen that a batch contains + # data from different episodes. Merging and resplitting the batch would then + # be the solution. + # Check, if batch comes actually from multiple episodes. + # episode_begin_indices = np.where(np.diff(np.hstack(batch["eps_id"])) != 0) + 1 + + # Define a container to collect episodes. episodes = [] - # Note, the `batch` will contain multiple rows and each row contains - # a single `SampleBatch` with experiences from one or more episodes. - # Loop over rows. - # for b in batch: - b = batch - # If multi-agent we need to extract the agent ID. - # TODO (simon): Check, what happens with the module ID. - if is_multi_agent: - agent_id = ( - # The old stack uses "agent_index" instead of "agent_id". - b[schema["agent_index"]][0] - if schema["agent_index"] in b - else None - ) - else: - agent_id = None + # Loop over `SampleBatch`es in the `ray.data` batch (a dict). + for i, obs in enumerate(batch[schema[Columns.OBS]]): - if is_multi_agent: - # TODO (simon): Add support for multi-agent episodes. - pass - else: - obs = ( - unpack_if_needed(b[schema[Columns.OBS]]) - if schema[Columns.OBS] in input_compress_columns - else b[schema[Columns.OBS]] - ) - obs.append( - unpack_if_needed(b[schema[Columns.NEXT_OBS]][-1]) - if schema[Columns.OBS] in input_compress_columns - else b[schema[Columns.NEXT_OBS]][-1] - ) - episode = SingleAgentEpisode( - id_=b[schema[Columns.EPS_ID]][0], - agent_id=agent_id, - observations=obs, - infos=( - b[schema[Columns.INFOS]] - if schema[Columns.INFOS] in b - else [{}] * len(obs) - ), - # Actions might be (a) serialized and/or (b) converted to a JSONable - # (when a composite space was used). We unserializer and then - # reconvert from JSONable to space sample. - actions=( - unpack_if_needed(b[schema[Columns.ACTIONS]]) - if Columns.ACTIONS in input_compress_columns - else b[schema[Columns.ACTIONS]] - ), - rewards=b[schema[Columns.REWARDS]], - terminated=( - any(b[schema[Columns.TERMINATEDS]]) - if schema[Columns.TERMINATEDS] in b - else any(b["dones"]) - ), - truncated=( - any(b[schema[Columns.TRUNCATEDS]]) - if schema[Columns.TRUNCATEDS] in b - else False - ), - # TODO (simon): Results in zero-length episodes in connector. - # t_started=batch[Columns.T if Columns.T in batch else - # "unroll_id"][i][0], - # TODO (simon): Single-dimensional columns are not supported. - # Extra model outputs might be serialized. We unserialize them here - # if needed. - # TODO (simon): Check, if we need here also reconversion from - # JSONable in case of composite spaces. - extra_model_outputs={ - k: unpack_if_needed(v) if k in input_compress_columns else v - for k, v in b.items() - if ( - k not in schema - and k not in schema.values() - and k not in ["dones", "agent_index", "type"] - ) - }, - len_lookback_buffer=0, - ) - if finalize: - episode.finalize() - episodes.append(episode) + # If multi-agent we need to extract the agent ID. + # TODO (simon): Check, what happens with the module ID. + if is_multi_agent: + agent_id = ( + # The old stack uses "agent_index" instead of "agent_id". + batch[schema["agent_index"]][i][0] + if schema["agent_index"] in batch + else None + ) + else: + agent_id = None + + if is_multi_agent: + # TODO (simon): Add support for multi-agent episodes. + pass + else: + # Unpack observations, if needed. + obs = ( + unpack_if_needed(obs.tolist()) + if schema[Columns.OBS] in input_compress_columns + else obs.tolist() + ) + # Append the last `new_obs` to get the correct length of observations. + obs.append( + unpack_if_needed(batch[schema[Columns.NEXT_OBS]][i][-1]) + if schema[Columns.OBS] in input_compress_columns + else batch[schema[Columns.NEXT_OBS]][i][-1] + ) + # Create a `SingleAgentEpisode`. + episode = SingleAgentEpisode( + id_=batch[schema[Columns.EPS_ID]][i][0], + agent_id=agent_id, + observations=obs, + infos=( + batch[schema[Columns.INFOS]][i] + if schema[Columns.INFOS] in batch + else [{}] * len(obs) + ), + # Actions might be (a) serialized. We unserialize them here. + actions=( + unpack_if_needed(batch[schema[Columns.ACTIONS]][i]) + if Columns.ACTIONS in input_compress_columns + else batch[schema[Columns.ACTIONS]][i] + ), + rewards=batch[schema[Columns.REWARDS]][i], + terminated=( + any(batch[schema[Columns.TERMINATEDS]][i]) + if schema[Columns.TERMINATEDS] in batch + else any(batch["dones"][i]) + ), + truncated=( + any(batch[schema[Columns.TRUNCATEDS]][i]) + if schema[Columns.TRUNCATEDS] in batch + else False + ), + # TODO (simon): Results in zero-length episodes in connector. + # t_started=batch[Columns.T if Columns.T in batch else + # "unroll_id"][i][0], + # TODO (simon): Single-dimensional columns are not supported. + # Extra model outputs might be serialized. We unserialize them here + # if needed. + # TODO (simon): Check, if we need here also reconversion from + # JSONable in case of composite spaces. + extra_model_outputs={ + k: unpack_if_needed(v[i]) + if k in input_compress_columns + else v[i] + for k, v in batch.items() + if ( + k not in schema + and k not in schema.values() + and k not in ["dones", "agent_index", "type"] + ) + }, + len_lookback_buffer=0, + ) + # Finalize, if necessary. + # TODO (simon, sven): Check, if we should convert all data to lists + # before. Right now only obs are lists. + if finalize: + episode.finalize() + episodes.append(episode) # Note, `map_batches` expects a `Dict` as return value. return {"episodes": episodes} diff --git a/rllib/offline/tests/test_offline_data.py b/rllib/offline/tests/test_offline_data.py index 123dbab55b15..222523f2630c 100644 --- a/rllib/offline/tests/test_offline_data.py +++ b/rllib/offline/tests/test_offline_data.py @@ -46,6 +46,24 @@ def test_offline_convert_to_episodes(self): self.assertTrue(len(episodes) == 10) self.assertTrue(isinstance(episodes[0], SingleAgentEpisode)) + def test_offline_convert_from_old_sample_batch_to_episodes(self): + + base_path = Path(__file__).parents[2] + sample_batch_data_path = base_path / "tests/data/cartpole/large.json" + config = AlgorithmConfig().offline_data( + input_=["local://" + sample_batch_data_path.as_posix()], + input_read_method="read_json", + input_read_sample_batches=True, + ) + + offline_data = OfflineData(config) + + batch = offline_data.data.take_batch(batch_size=10) + episodes = OfflinePreLearner._map_batch_to_episode(False, batch)["episodes"] + + self.assertTrue(len(episodes) == 10) + self.assertTrue(isinstance(episodes[0], SingleAgentEpisode)) + def test_sample(self): config = AlgorithmConfig().offline_data(input_=[self.data_path]) diff --git a/rllib/tuned_examples/bc/cartpole_bc.py b/rllib/tuned_examples/bc/cartpole_bc.py index 7bae309b2df4..bd7ea96f7f2e 100644 --- a/rllib/tuned_examples/bc/cartpole_bc.py +++ b/rllib/tuned_examples/bc/cartpole_bc.py @@ -23,12 +23,11 @@ # Define the data paths. data_path = "tests/data/cartpole/cartpole-v1_large" -data_path = "tests/data/cartpole/large.json" base_path = Path(__file__).parents[2] print(f"base_path={base_path}") data_path = "local://" / base_path / data_path print(f"data_path={data_path}") -args.no_tune = True + # Define the BC config. config = ( BCConfig() @@ -50,8 +49,6 @@ # as remote learners. .offline_data( input_=[data_path.as_posix()], - input_read_method="read_json", - input_read_sample_batches=True, # Define the number of reading blocks, these should be larger than 1 # and aligned with the data size. input_read_method_kwargs={"override_num_blocks": max(args.num_gpus, 2)}, @@ -76,25 +73,6 @@ ) ) -# algo = config.build() -# from ray.rllib.offline.offline_prelearner import SCHEMA - -# oplr = algo.offline_data.prelearner_class( -# config=algo.offline_data.config, -# learner=algo.offline_data.learner_handles[0], -# spaces=algo.offline_data.spaces["__env__"], -# ) - - -# rows = algo.offline_data.data.take(2) -# algo.offline_data.prelearner_class._map_batch_to_episode( -# False, -# rows, -# finalize=False, -# schema=SCHEMA | algo.offline_data.config.input_read_schema, -# input_compress_columns=algo.offline_data.config.input_compress_columns, -# ) - stop = { f"{EVALUATION_RESULTS}/{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}": 120.0, TRAINING_ITERATION_TIMER: 350, From 1eca84b595d12f220deb533a4aed763f5192b263 Mon Sep 17 00:00:00 2001 From: simonsays1980 Date: Wed, 28 Aug 2024 10:57:35 +0200 Subject: [PATCH 3/6] Removed some relicts from earlier tryouts in 'OfflineData'. Signed-off-by: simonsays1980 --- rllib/offline/offline_data.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/rllib/offline/offline_data.py b/rllib/offline/offline_data.py index 36c2ea3c592c..5c8c3e98761a 100644 --- a/rllib/offline/offline_data.py +++ b/rllib/offline/offline_data.py @@ -42,19 +42,9 @@ def __init__(self, config: AlgorithmConfig): logger.error(e) # Avoids reinstantiating the batch iterator each time we sample. self.batch_iterator = None - self.map_method = ( - "map" - if self.config.input_read_episodes or self.config.input_read_sample_batches - else "map_batches" - ) self.map_batches_kwargs = ( self.default_map_batches_kwargs | self.config.map_batches_kwargs ) - self.iter_method = ( - "iter_rows" - if self.config.input_read_episodes or self.config.input_read_sample_batches - else "iter_batches" - ) self.iter_batches_kwargs = ( self.default_iter_batches_kwargs | self.config.iter_batches_kwargs ) From abcd0668e8ea24ee1ab8e29a35c822723b5e04ae Mon Sep 17 00:00:00 2001 From: simonsays1980 Date: Wed, 28 Aug 2024 11:08:08 +0200 Subject: [PATCH 4/6] CHanged some comment and added a missing apostrophe that was raising an error in building the docs. Signed-off-by: simonsays1980 --- rllib/algorithms/algorithm_config.py | 6 +++--- rllib/offline/offline_data.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/rllib/algorithms/algorithm_config.py b/rllib/algorithms/algorithm_config.py index d4fdd4bfa610..f61f5cdffd76 100644 --- a/rllib/algorithms/algorithm_config.py +++ b/rllib/algorithms/algorithm_config.py @@ -2466,9 +2466,9 @@ def offline_data( inside of RLlib's schema. The other format is a columnar format and is agnostic to the RL framework used. Use the latter format, if you are unsure when to use the data or in which RL framework. The default is - to read column data, i.e. `False`. input_read_episodes` and - `inpuit_read_sample_batches` cannot be `True` at the same time.See - also `output_write_episodes`to define the output data format when + to read column data, i.e. `False`. `input_read_episodes` and + `inpuit_read_sample_batches` cannot be `True` at the same time. See + also `output_write_episodes` to define the output data format when recording. input_read_sample_batches: Whether offline data is stored in RLlib's old stack `SampleBatch` type. This is usually the case for older data diff --git a/rllib/offline/offline_data.py b/rllib/offline/offline_data.py index 5c8c3e98761a..9dd102616e75 100644 --- a/rllib/offline/offline_data.py +++ b/rllib/offline/offline_data.py @@ -25,7 +25,7 @@ def __init__(self, config: AlgorithmConfig): self.path = ( config.input_ if isinstance(config.input_, list) else Path(config.input_) ) - # Use `read_json` as default data read method. + # Use `read_parquet` as default data read method. self.data_read_method = config.input_read_method # Override default arguments for the data read method. self.data_read_method_kwargs = ( From 54b4724d936e24232ebb657ecc75690f8256766b Mon Sep 17 00:00:00 2001 From: simonsays1980 Date: Wed, 28 Aug 2024 11:09:46 +0200 Subject: [PATCH 5/6] Added missing JSON file for testing conversion from 'SampleBatch' to 'SingleAgentEpisode' to BUILD file. Signed-off-by: simonsays1980 --- rllib/BUILD | 1 + 1 file changed, 1 insertion(+) diff --git a/rllib/BUILD b/rllib/BUILD index e681e4c252a0..3647830a52bf 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -1867,6 +1867,7 @@ py_test( srcs = ["offline/tests/test_offline_data.py"], data = [ "tests/data/cartpole/cartpole-v1_large", + "tests/data/cartpole/large.json", ], ) From edc5691e8d4eee74f6bcf468963f093cd4340331 Mon Sep 17 00:00:00 2001 From: simonsays1980 Date: Wed, 28 Aug 2024 12:25:27 +0200 Subject: [PATCH 6/6] Added suggestions from @sven1977's review. Signed-off-by: simonsays1980 --- rllib/algorithms/marwil/marwil_offline_prelearner.py | 2 +- rllib/offline/offline_prelearner.py | 8 ++++---- rllib/offline/tests/test_offline_data.py | 4 +++- rllib/tuned_examples/bc/cartpole_bc.py | 4 ++-- rllib/tuned_examples/bc/pendulum_bc.py | 4 ++-- 5 files changed, 12 insertions(+), 10 deletions(-) diff --git a/rllib/algorithms/marwil/marwil_offline_prelearner.py b/rllib/algorithms/marwil/marwil_offline_prelearner.py index e39d8998775f..5ba39feab5ae 100644 --- a/rllib/algorithms/marwil/marwil_offline_prelearner.py +++ b/rllib/algorithms/marwil/marwil_offline_prelearner.py @@ -28,7 +28,7 @@ def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, MultiAgentBatch]: episodes = batch["item"].tolist() # Else, if we have old stack `SampleBatch`es. elif self.input_read_sample_batches: - episodes = OfflinePreLearner._map_batch_to_episode( + episodes = OfflinePreLearner._map_sample_batch_to_episode( self._is_multi_agent, batch, finalize=False, diff --git a/rllib/offline/offline_prelearner.py b/rllib/offline/offline_prelearner.py index c475aa62a0a9..ea7813f50d04 100644 --- a/rllib/offline/offline_prelearner.py +++ b/rllib/offline/offline_prelearner.py @@ -147,7 +147,7 @@ def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, List[EpisodeType]] episodes = batch["item"].tolist() # Else, if we have old stack `SampleBatch`es. elif self.input_read_sample_batches: - episodes = OfflinePreLearner._map_batch_to_episode( + episodes = OfflinePreLearner._map_sample_batch_to_episode( self._is_multi_agent, batch, finalize=False, @@ -281,7 +281,7 @@ def convert(sample, space): if is_multi_agent: # TODO (simon): Add support for multi-agent episodes. - pass + NotImplementedError else: # Build a single-agent episode with a single row of the batch. episode = SingleAgentEpisode( @@ -359,7 +359,7 @@ def convert(sample, space): # Note, `map_batches` expects a `Dict` as return value. return {"episodes": episodes} - def _map_batch_to_episode( + def _map_sample_batch_to_episode( is_multi_agent: bool, batch: Dict[str, Union[list, np.ndarray]], schema: Dict[str, str] = SCHEMA, @@ -396,7 +396,7 @@ def _map_batch_to_episode( if is_multi_agent: # TODO (simon): Add support for multi-agent episodes. - pass + NotImplementedError else: # Unpack observations, if needed. obs = ( diff --git a/rllib/offline/tests/test_offline_data.py b/rllib/offline/tests/test_offline_data.py index 222523f2630c..5e67dbb2ef50 100644 --- a/rllib/offline/tests/test_offline_data.py +++ b/rllib/offline/tests/test_offline_data.py @@ -59,7 +59,9 @@ def test_offline_convert_from_old_sample_batch_to_episodes(self): offline_data = OfflineData(config) batch = offline_data.data.take_batch(batch_size=10) - episodes = OfflinePreLearner._map_batch_to_episode(False, batch)["episodes"] + episodes = OfflinePreLearner._map_sample_batch_to_episode(False, batch)[ + "episodes" + ] self.assertTrue(len(episodes) == 10) self.assertTrue(isinstance(episodes[0], SingleAgentEpisode)) diff --git a/rllib/tuned_examples/bc/cartpole_bc.py b/rllib/tuned_examples/bc/cartpole_bc.py index bd7ea96f7f2e..f52b655a2d18 100644 --- a/rllib/tuned_examples/bc/cartpole_bc.py +++ b/rllib/tuned_examples/bc/cartpole_bc.py @@ -1,11 +1,11 @@ from pathlib import Path +from ray.air.constants import TRAINING_ITERATION from ray.rllib.algorithms.bc import BCConfig from ray.rllib.utils.metrics import ( ENV_RUNNER_RESULTS, EPISODE_RETURN_MEAN, EVALUATION_RESULTS, - TRAINING_ITERATION_TIMER, ) from ray.rllib.utils.test_utils import ( add_rllib_example_script_args, @@ -75,7 +75,7 @@ stop = { f"{EVALUATION_RESULTS}/{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}": 120.0, - TRAINING_ITERATION_TIMER: 350, + TRAINING_ITERATION: 350, } if __name__ == "__main__": diff --git a/rllib/tuned_examples/bc/pendulum_bc.py b/rllib/tuned_examples/bc/pendulum_bc.py index ffd26e471b53..6988f0989ac6 100644 --- a/rllib/tuned_examples/bc/pendulum_bc.py +++ b/rllib/tuned_examples/bc/pendulum_bc.py @@ -1,11 +1,11 @@ from pathlib import Path +from ray.air.constants import TRAINING_ITERATION from ray.rllib.algorithms.bc import BCConfig from ray.rllib.utils.metrics import ( ENV_RUNNER_RESULTS, EPISODE_RETURN_MEAN, EVALUATION_RESULTS, - TRAINING_ITERATION_TIMER, ) from ray.rllib.utils.test_utils import ( add_rllib_example_script_args, @@ -61,7 +61,7 @@ stop = { f"{EVALUATION_RESULTS}/{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}": -200.0, - TRAINING_ITERATION_TIMER: 350.0, + TRAINING_ITERATION: 350, } if __name__ == "__main__":