Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RLlib; Offline RL] Offline performance cleanup. #47731

Merged
merged 16 commits into from
Sep 25, 2024
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
Show all changes
16 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions rllib/algorithms/algorithm_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,8 @@ def __init__(self, algo_class: Optional[type] = None):
self.input_filesystem_kwargs = {}
self.input_compress_columns = [Columns.OBS, Columns.NEXT_OBS]
self.input_spaces_jsonable = True
self.materialize_data = False
self.materialize_mapped_data = True
self.map_batches_kwargs = {}
self.iter_batches_kwargs = {}
self.prelearner_class = None
Expand Down Expand Up @@ -2467,6 +2469,8 @@ def offline_data(
input_filesystem: Optional[str] = NotProvided,
input_filesystem_kwargs: Optional[Dict] = NotProvided,
input_compress_columns: Optional[List[str]] = NotProvided,
materialize_data: Optional[bool] = NotProvided,
materialize_mapped_data: Optional[bool] = NotProvided,
map_batches_kwargs: Optional[Dict] = NotProvided,
iter_batches_kwargs: Optional[Dict] = NotProvided,
prelearner_class: Optional[Type] = NotProvided,
Expand Down Expand Up @@ -2556,6 +2560,31 @@ def offline_data(
`MultiAgentEpisode` not supported, yet). Note,
`rllib.core.columns.Columns.OBS` will also try to decompress
`rllib.core.columns.Columns.NEXT_OBS`.
materialize_data: Whether the raw data should be materialized in memory.
This boosts performance, but requires enough memory to avoid an OOM, so
make sure that your cluster has the resources available. For very large
data you might want to switch to streaming mode by setting this to
`False` (default). If your algorithm does not need the RLModule in the
Learner connector pipeline or all (learner) connectors are stateless
you should consider setting `materialize_mapped_data` to `True`
instead (and set `materialize_data` to `False`). If your data does not
fit into memory and your Learner connector pipeline requires an RLModule
or is stateful, set both `materialize_data` and
`materialize_mapped_data` to `False`.
materialize_mapped_data: Whether the data should be materialized after
running it through the Learner connector pipeline (i.e. after running
the `OfflinePreLearner`). This improves performance, but should only be
used in case the (learner) connector pipeline does not require an
RLModule and the (learner) connector pipeline is stateless. For example,
MARWIL's Learner connector pipeline requires the RLModule for value
function predictions and training batches would become stale after some
iterations causing learning degradation or divergence. Also ensure that
your cluster has enough memory available to avoid an OOM. If set to
`True` (default), make sure that `materialize_data` is set to `False` to
avoid materialization of two datasets. If your data does not fit into
memory and your Learner connector pipeline requires an RLModule or is
stateful, set both `materialize_data` and `materialize_mapped_data` to
`False`.
map_batches_kwargs: Keyword args for the `map_batches` method. These will be
passed into the `ray.data.Dataset.map_batches` method when sampling
without checking. If no arguments passed in the default arguments
Expand Down Expand Up @@ -2658,6 +2687,10 @@ def offline_data(
self.input_filesystem_kwargs = input_filesystem_kwargs
if input_compress_columns is not NotProvided:
self.input_compress_columns = input_compress_columns
if materialize_data is not NotProvided:
self.materialize_data = materialize_data
if materialize_mapped_data is not NotProvided:
self.materialize_mapped_data = materialize_mapped_data
if map_batches_kwargs is not NotProvided:
self.map_batches_kwargs = map_batches_kwargs
if iter_batches_kwargs is not NotProvided:
Expand Down
10 changes: 4 additions & 6 deletions rllib/algorithms/bc/bc.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from ray.rllib.algorithms.bc.bc_catalog import BCCatalog
from ray.rllib.algorithms.marwil.marwil import MARWIL, MARWILConfig
from ray.rllib.core.rl_module.rl_module import RLModuleSpec
from ray.rllib.offline.offline_prelearner import OfflinePreLearner
from ray.rllib.utils.annotations import override
from ray.rllib.utils.typing import ResultDict, RLModuleSpecType

Expand Down Expand Up @@ -61,11 +60,10 @@ def __init__(self, algo_class=None):
# not important for behavioral cloning.
self.postprocess_inputs = False

# Set the offline prelearner to the default one. Note, MARWIL's
# specified offline prelearner requests a value function that
# BC does not have. Furthermore, MARWIL's prelearner calculates
# advantages unneeded for BC.
self.prelearner_class = OfflinePreLearner
# Materialize only the mapped data. This is optimal as long
# as no connector in the connector pipeline holds a state.
self.materialize_data = False
self.materialize_mapped_data = True
# __sphinx_doc_end__
# fmt: on

Expand Down
12 changes: 12 additions & 0 deletions rllib/algorithms/marwil/marwil.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,18 @@ def __init__(self, algo_class=None):
"type": "StochasticSampling",
# Add constructor kwargs here (if any).
}

# Materialize only the data in raw format, but not the mapped data b/c
# MARWIL uses a connector to calculate values and therefore the module
# needs to be updated frequently. This updating would not work if we
# map the data once at the beginning.
# TODO (simon, sven): The module is only updated when the OfflinePreLearner
# gets reinitiated, i.e. when the iterator gets reinitiated. This happens
# frequently enough with a small dataset, but with a big one this does not
# update often enough. We might need to put model weights every couple of
# iterations into the object storage (maybe also connector states).
self.materialize_data = True
self.materialize_mapped_data = False
# __sphinx_doc_end__
# fmt: on
self._set_off_policy_estimation_methods = False
Expand Down
3 changes: 1 addition & 2 deletions rllib/algorithms/marwil/torch/marwil_torch_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,7 @@ def possibly_masked_mean(data_):
else:
# cumulative_rewards = batch[Columns.ADVANTAGES]
value_fn_out = fwd_out[Columns.VF_PREDS]
# advantages = cumulative_rewards - value_fn_out
advantages = batch[Columns.ADVANTAGES]
advantages = batch[Columns.VALUE_TARGETS] - value_fn_out
advantages_squared_mean = possibly_masked_mean(torch.pow(advantages, 2.0))

# Compute the value loss.
Expand Down
13 changes: 10 additions & 3 deletions rllib/core/learner/learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1093,7 +1093,6 @@ def update_from_iterator(
)

self._check_is_built()
# minibatch_size = minibatch_size or 32

# Call `before_gradient_based_update` to allow for non-gradient based
# preparations-, logging-, and update logic to happen.
Expand All @@ -1106,8 +1105,11 @@ def _finalize_fn(batch: Dict[str, numpy.ndarray]) -> Dict[str, Any]:
return {"batch": self._set_slicing_by_batch_id(batch, value=True)}

i = 0
logger.debug(f"===> [Learner {id(self)}]: SLooping through batches ... ")
for batch in iterator.iter_batches(
batch_size=minibatch_size,
# Note, this needs to be one b/c data is already mapped to
# `MultiAgentBatch`es of `minibatch_size`.
batch_size=1,
_finalize_fn=_finalize_fn,
**kwargs,
):
Expand All @@ -1116,6 +1118,9 @@ def _finalize_fn(batch: Dict[str, numpy.ndarray]) -> Dict[str, Any]:

# Note, `_finalize_fn` must return a dictionary.
batch = batch["batch"]
logger.debug(
f"===> [Learner {id(self)}]: batch {i} with {batch.env_steps()} rows."
)
# Check the MultiAgentBatch, whether our RLModule contains all ModuleIDs
# found in this batch. If not, throw an error.
unknown_module_ids = set(batch.policy_batches.keys()) - set(
Expand All @@ -1141,7 +1146,9 @@ def _finalize_fn(batch: Dict[str, numpy.ndarray]) -> Dict[str, Any]:
if num_iters and i == num_iters:
break

logger.info(f"[Learner] Iterations run in epoch: {i}")
logger.info(
f"===> [Learner {id(self)}] number of iterations run in this epoch: {i}"
)
# Convert logged tensor metrics (logged during tensor-mode of MetricsLogger)
# to actual (numpy) values.
self.metrics.tensors_to_numpy(tensor_metrics)
Expand Down
147 changes: 103 additions & 44 deletions rllib/offline/offline_data.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import logging
from pathlib import Path
import ray
import time
import types

from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
from ray.rllib.core import COMPONENT_RL_MODULE
Expand All @@ -21,17 +23,27 @@ class OfflineData:
def __init__(self, config: AlgorithmConfig):

self.config = config
self.is_multi_agent = config.is_multi_agent()
self.is_multi_agent = self.config.is_multi_agent()
self.path = (
config.input_ if isinstance(config.input_, list) else Path(config.input_)
self.config.input_
if isinstance(config.input_, list)
else Path(config.input_)
)
# Use `read_parquet` as default data read method.
self.data_read_method = config.input_read_method
self.data_read_method = self.config.input_read_method
# Override default arguments for the data read method.
self.data_read_method_kwargs = (
self.default_read_method_kwargs | config.input_read_method_kwargs
self.default_read_method_kwargs | self.config.input_read_method_kwargs
)

# If data should be materialized.
self.materialize_data = config.materialize_data
simonsays1980 marked this conversation as resolved.
Show resolved Hide resolved
# If mapped data should be materialized.
self.materialize_mapped_data = config.materialize_mapped_data
simonsays1980 marked this conversation as resolved.
Show resolved Hide resolved
# Flag to identify, if data has already been mapped with the
# `OfflinePreLearner`.
self.data_is_mapped = False

# Set the filesystem.
self.filesystem = self.config.output_filesystem
self.filesystem_kwargs = self.config.output_filesystem_kwargs
Expand Down Expand Up @@ -67,9 +79,14 @@ def __init__(self, config: AlgorithmConfig):

try:
# Load the dataset.
start_time = time.perf_counter()
self.data = getattr(ray.data, self.data_read_method)(
self.path, **self.data_read_method_kwargs
)
if self.materialize_data:
self.data = self.data.materialize()
stop_time = time.perf_counter()
logger.debug(f"Time for loading dataset: {stop_time - start_time}s.")
logger.info("Reading data from {}".format(self.path))
logger.info(self.data.schema())
except Exception as e:
Expand All @@ -96,32 +113,23 @@ def sample(
return_iterator: bool = False,
num_shards: int = 1,
):
if (
not return_iterator or (return_iterator and num_shards <= 1)
) and not self.batch_iterator:
# If no iterator should be returned, or if we want to return a single
# batch iterator, we instantiate the batch iterator once, here.
# TODO (simon, sven): The iterator depends on the `num_samples`, i.e.
# sampling later with a different batch size would need a
# reinstantiation of the iterator.
self.batch_iterator = self.data.map_batches(
self.prelearner_class,
fn_constructor_kwargs={
"config": self.config,
"learner": self.learner_handles[0],
"spaces": self.spaces[INPUT_ENV_SPACES],
},
batch_size=num_samples,
**self.map_batches_kwargs,
).iter_batches(
batch_size=num_samples,
**self.iter_batches_kwargs,
)

# Do we want to return an iterator or a single batch?
if return_iterator:
# In case of multiple shards, we return multiple
# `StreamingSplitIterator` instances.
# Materialize the mapped data, if necessary. This runs for all the
# data the `OfflinePreLearner` logic and maps them to `MultiAgentBatch`es.
# TODO (simon, sven): This would never update the module nor the
# connectors. If this is needed we have to check, if we give
# (a) only an iterator and let the learner and OfflinePreLearner
# communicate through the object storage. This only works when
# not materializing.
# (b) Rematerialize the data every couple of iterations. This
# is costly.
if not self.data_is_mapped:
# Constructor `kwargs` for the `OfflinePreLearner`.
fn_constructor_kwargs = {
"config": self.config,
"learner": self.learner_handles[0],
"spaces": self.spaces[INPUT_ENV_SPACES],
}
# If we have multiple learners, add to the constructor `kwargs`.
if num_shards > 1:
# Call here the learner to get an up-to-date module state.
# TODO (simon): This is a workaround as long as learners cannot
Expand All @@ -131,31 +139,82 @@ def sample(
component=COMPONENT_RL_MODULE
)
)
return self.data.map_batches(
# TODO (cheng su): At best the learner handle passed in here should
# be the one from the learner that is nearest, but here we cannot
# provide locality hints.
self.prelearner_class,
fn_constructor_kwargs={
"config": self.config,
# Add constructor `kwargs` when using remote learners.
fn_constructor_kwargs.update(
{
"learner": self.learner_handles,
"spaces": self.spaces["__env__"],
"locality_hints": self.locality_hints,
"module_spec": self.module_spec,
"module_state": module_state,
},
batch_size=num_samples,
**self.map_batches_kwargs,
).streaming_split(
n=num_shards, equal=False, locality_hints=self.locality_hints
}
)

self.data = self.data.map_batches(
self.prelearner_class,
fn_constructor_kwargs=fn_constructor_kwargs,
batch_size=num_samples,
**self.map_batches_kwargs,
)
# Set the flag to `True`.
self.data_is_mapped = True
# If the user wants to materialize the data in memory.
if self.materialize_mapped_data:
self.data = self.data.materialize()
# Build an iterator, if necessary.
if (not self.batch_iterator and (not return_iterator or num_shards <= 1)) or (
return_iterator and isinstance(self.batch_iterator, types.GeneratorType)
):
# If no iterator should be returned, or if we want to return a single
# batch iterator, we instantiate the batch iterator once, here.
# TODO (simon, sven): The iterator depends on the `num_samples`, i.e.
# sampling later with a different batch size would need a
# reinstantiation of the iterator.
self.batch_iterator = self.data.iter_batches(
# This is important. The batch size is now 1, because the data
# is already run through the `OfflinePreLearner` and a single
# instance is a single `MultiAgentBatch` of size `num_samples`.
batch_size=1,
**self.iter_batches_kwargs,
)

if not return_iterator:
self.batch_iterator = iter(self.batch_iterator)

# Do we want to return an iterator or a single batch?
if return_iterator:
# In case of multiple shards, we return multiple
# `StreamingSplitIterator` instances.
if num_shards > 1:
# TODO (simon): Check, if we should use `iter_batches_kwargs` here
# as well.
logger.debug("===> [OfflineData]: Return streaming_split ... ")
return self.data.streaming_split(
n=num_shards,
# Note, `equal` must be `True`, i.e. the batch size must
# be the same for all batches b/c otherwise remote learners
# could block each others.
equal=True,
locality_hints=self.locality_hints,
)

# Otherwise, we return a simple batch `DataIterator`.
else:
return self.batch_iterator
else:
# Return a single batch from the iterator.
return next(iter(self.batch_iterator))["batch"][0]
try:
return next(self.batch_iterator)["batch"][0]
except StopIteration:
# If the batch iterator is exhausted, reinitiate a new one.
logger.debug(
"===> [OfflineData]: Batch iterator exhausted. Reinitiating ..."
)
self.batch_iterator = None
return self.sample(
num_samples=num_samples,
return_iterator=return_iterator,
num_shards=num_shards,
)

@property
def default_read_method_kwargs(self):
Expand Down
3 changes: 3 additions & 0 deletions rllib/offline/offline_prelearner.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import gymnasium as gym
import logging
import numpy as np
import random
from typing import Any, Dict, List, Optional, Union, Tuple, TYPE_CHECKING
Expand Down Expand Up @@ -46,6 +47,8 @@
"unroll_id": "unroll_id",
}

logger = logging.getLogger(__name__)


@ExperimentalAPI
class OfflinePreLearner:
Expand Down
Loading
Loading