[RLlib; Offline RL] Add support to directly read from episodes. #46865

Merged

45 commits
cbfd05f
Initiated MARWIL RL Module and added catalog, learner and tf_learner.
simonsays1980 Sep 8, 2023
c488da7
Added MARWIL RL Module and started to write test.
simonsays1980 Sep 8, 2023
9078af8
Merge branch 'master' into marwil-rl-module
simonsays1980 Apr 24, 2024
b4e1795
Implemented Torch version of MARWIL.
simonsays1980 Apr 25, 2024
5eeb2e6
Added torch learner.
simonsays1980 Apr 25, 2024
a1928bc
Merged master.
simonsays1980 Jul 23, 2024
3fcef32
Moved training step logic from BC to MARWIL.
simonsays1980 Jul 23, 2024
e9abc27
Setup MARWIL with the new stack using 'OfflineData'. This is an unfor…
simonsays1980 Jul 23, 2024
c930464
Fixed multiple bugs in 'MARWILOfflinePreLearner', 'MARWILTorchLearner…
simonsays1980 Jul 24, 2024
8b575db
LINTER.
simonsays1980 Jul 24, 2024
1325beb
Merged Master
simonsays1980 Jul 24, 2024
7785bda
Removed tensorflow and fixed a small bug.
simonsays1980 Jul 24, 2024
3e680cc
Readded 'input_read_schema' b/c it was accidentally removed.
simonsays1980 Jul 24, 2024
90c8d03
Readded further tests for MARWIL on continuous actions and its loss f…
simonsays1980 Jul 25, 2024
8a8d5c5
Added default 'prelearner_class' to 'MARWILConfig'.
simonsays1980 Jul 25, 2024
5f051e0
Added example to 'tuned_examples' for MARWIL.
simonsays1980 Jul 25, 2024
f247e7c
Added BC and MARWIL tuned_examples to learning tests.
simonsays1980 Jul 25, 2024
96d1e8e
Moved definition of prelearner class from 'MARWILConfig.offline_data'…
simonsays1980 Jul 25, 2024
9a29bf1
Fixed path for data in BUILD.
simonsays1980 Jul 25, 2024
0dc7c8a
Removed a duplicated forward slash from BUILD file that led to errors…
simonsays1980 Jul 26, 2024
061804e
Merged master.
simonsays1980 Jul 26, 2024
ff41d12
Added main to MARWIL RLModule test.
simonsays1980 Jul 26, 2024
122dc50
Added tests for old stack MARWIL.
simonsays1980 Jul 26, 2024
e7b05c7
Merge branch 'master' into marwil-rl-module
simonsays1980 Jul 29, 2024
f7f9d89
Fixed a circular import in 'test_offline_data'.
simonsays1980 Jul 29, 2024
e0a0351
Added 'finalize' to 'OfflinePreLearner._map_to_episodes' b/c some con…
simonsays1980 Jul 29, 2024
5632061
Set 'OfflinePreLearner' as BC's default prelearner class in the confi…
simonsays1980 Jul 30, 2024
9409a1e
Fixed a small bug in MARWIL when writing to the metrics logger and se…
simonsays1980 Jul 30, 2024
8222454
Deprecated TensorFlow support for BC and removed bc-specific learners…
simonsays1980 Jul 30, 2024
5afc3ab
Fixed old stack test by removing hybrid stack.
simonsays1980 Jul 30, 2024
399e078
In regard to the upcoming offline recording feature added the option …
simonsays1980 Jul 30, 2024
b027890
Fixed file path for MARWIL learning test in BUILD.
simonsays1980 Jul 31, 2024
1b589bf
Merge branch 'marwil-rl-module' into offline-prelearner-from-episodes
simonsays1980 Jul 31, 2024
e86e2dd
Added @sven1977's review.
simonsays1980 Jul 31, 2024
429b5f7
Added new test file to test folder.
simonsays1980 Jul 31, 2024
94dd369
Merged Master
simonsays1980 Jul 31, 2024
0e73c14
Fixed pretrain BC example as it was using the hybrid stack which is d…
simonsays1980 Jul 31, 2024
c781c58
Removed training step for hybrid stack and renamed test files.
simonsays1980 Jul 31, 2024
b216d6a
Fixed small typo.
simonsays1980 Jul 31, 2024
fff9db1
Merge branch 'marwil-rl-module' into offline-prelearner-from-episodes
simonsays1980 Jul 31, 2024
d2e9a08
Removed override of 'get_default_learner_class' in BC. BC takes now t…
simonsays1980 Aug 1, 2024
cf39a2e
Merged Master
simonsays1980 Aug 7, 2024
351c8a1
Removed duplicate and wrong MARWIL tuned example learning test.
simonsays1980 Aug 7, 2024
fcca1b4
Fixed some linting errors.
simonsays1980 Aug 7, 2024
393b7d8
Set MARWIL learning test to 'large' after it timeouts.
simonsays1980 Aug 9, 2024
14 changes: 14 additions & 0 deletions rllib/algorithms/algorithm_config.py
@@ -431,6 +431,7 @@ def __init__(self, algo_class: Optional[type] = None):
self.input_read_method = "read_parquet"
self.input_read_method_kwargs = {}
self.input_read_schema = {}
self.input_read_episodes = False
self.map_batches_kwargs = {}
self.iter_batches_kwargs = {}
self.prelearner_class = None
@@ -2385,6 +2386,7 @@ def offline_data(
input_read_method: Optional[Union[str, Callable]] = NotProvided,
input_read_method_kwargs: Optional[Dict] = NotProvided,
input_read_schema: Optional[Dict[str, str]] = NotProvided,
input_read_episodes: Optional[bool] = NotProvided,
map_batches_kwargs: Optional[Dict] = NotProvided,
iter_batches_kwargs: Optional[Dict] = NotProvided,
prelearner_class: Optional[Type] = NotProvided,
@@ -2437,6 +2439,16 @@
schema used is `ray.rllib.offline.offline_data.SCHEMA`. If your data set
contains already the names in this schema, no `input_read_schema` is
needed.
input_read_episodes: Whether the offline data is already stored in RLlib's
`EpisodeType` format, i.e. `ray.rllib.env.SingleAgentEpisode` (multi-agent
episodes are planned but not yet supported). Reading episodes directly
avoids an additional transformation step and is usually faster; it is
therefore the advised format when your application remains fully inside
RLlib's schema. The alternative is a columnar format, which is agnostic
to the RL framework used. Use the columnar format if you are unsure how
or in which RL framework the data will be used later. The default is to
read column data, i.e. `False`. See also `output_write_episodes` to define
the output data format when recording.
map_batches_kwargs: `kwargs` for the `map_batches` method. These will be
passed into the `ray.data.Dataset.map_batches` method when sampling
without checking. If no arguments passed in the default arguments `{
@@ -2528,6 +2540,8 @@ def offline_data(
self.input_read_method_kwargs = input_read_method_kwargs
if input_read_schema is not NotProvided:
self.input_read_schema = input_read_schema
if input_read_episodes is not NotProvided:
self.input_read_episodes = input_read_episodes
if map_batches_kwargs is not NotProvided:
self.map_batches_kwargs = map_batches_kwargs
if iter_batches_kwargs is not NotProvided:
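To illustrate the new option, here is a minimal, hedged sketch of a config that opts into reading episode data directly. The environment and data path are placeholders, not part of this PR; only `input_read_episodes` is the new flag.

```python
from ray.rllib.algorithms.marwil import MARWILConfig

# Minimal sketch; environment and data path are placeholders.
config = (
    MARWILConfig()
    .environment("CartPole-v1")
    .offline_data(
        # Data previously recorded as `SingleAgentEpisode` objects.
        input_="local:///tmp/cartpole/episode-data",
        # Read the episodes directly instead of columnar data
        # (the default is False, i.e. columnar input).
        input_read_episodes=True,
    )
)
```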
7 changes: 4 additions & 3 deletions rllib/algorithms/marwil/marwil.py
@@ -192,7 +192,7 @@ def get_default_rl_module_spec(self) -> RLModuleSpecType:
else:
raise ValueError(
f"The framework {self.framework_str} is not supported. "
"Use either 'torch' or 'tf2'."
"Use 'torch' instead."
)

@override(AlgorithmConfig)
@@ -205,7 +205,8 @@ def get_default_learner_class(self) -> Union[Type["Learner"], str]:
return MARWILTorchLearner
else:
raise ValueError(
f"The framework {self.framework_str} is not supported. " "Use 'torch'."
f"The framework {self.framework_str} is not supported. "
"Use 'torch' instead."
)

@override(AlgorithmConfig)
@@ -324,7 +325,7 @@ def training_step(self) -> ResultDict:
elif self.config.enable_rl_module_and_learner:
raise ValueError(
"`enable_rl_module_and_learner=True`. Hybrid stack is not "
"is not supported for MARWIL. Either use the old stack with "
"supported for MARWIL. Either use the old stack with "
"`ModelV2` or the new stack with `RLModule`. You can enable "
"the new stack by setting both, `enable_rl_module_and_learner` "
"and `enable_env_runner_and_connector_v2` to `True`."
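Since the hybrid stack is rejected here, a short sketch of the configuration the error message asks for; the flag names are taken from the message, everything else is a placeholder.

```python
from ray.rllib.algorithms.marwil import MARWILConfig

config = (
    MARWILConfig()
    # Enable the new API stack; the hybrid combination is not supported.
    .api_stack(
        enable_rl_module_and_learner=True,
        enable_env_runner_and_connector_v2=True,
    )
)
```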
18 changes: 16 additions & 2 deletions rllib/algorithms/marwil/marwil_catalog.py
@@ -16,8 +16,22 @@ class MARWILCatalog(Catalog):
"""The Catalog class used to build models for MARWIL.

MARWILCatalog provides the following models:


- ActorCriticEncoder: The encoder used to encode the observations.
- Pi Head: The head used to compute the policy logits.
- Value Function Head: The head used to compute the value function.

The ActorCriticEncoder is a wrapper around Encoders to produce separate outputs
for the policy and value function. See implementations of MARWILRLModule for
more details.

Any custom ActorCriticEncoder can be built by overriding the
build_actor_critic_encoder() method. Alternatively, the ActorCriticEncoderConfig
at MARWILCatalog.actor_critic_encoder_config can be overridden to build a custom
ActorCriticEncoder during RLModule runtime.

Any custom head can be built by overriding the build_pi_head() and build_vf_head()
methods. Alternatively, the PiHeadConfig and VfHeadConfig can be overridden to
build custom heads during RLModule runtime.
"""

def __init__(
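As a hedged illustration of the customization path the docstring describes (assuming `MARWILCatalog` exposes the same `build_pi_head(framework)` hook as other RLlib catalogs), a subclass could look like this:

```python
from ray.rllib.algorithms.marwil.marwil_catalog import MARWILCatalog


class MyMARWILCatalog(MARWILCatalog):
    def build_pi_head(self, framework: str):
        # Delegate to the default pi head here; a real override would return
        # a custom model instead.
        return super().build_pi_head(framework=framework)
```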
20 changes: 14 additions & 6 deletions rllib/offline/offline_prelearner.py
@@ -86,6 +86,7 @@ def __init__(
):

self.config = config
self.input_read_episodes = self.config.input_read_episodes
# We need this learner to run the learner connector pipeline.
# If it is a `Learner` instance, the `Learner` is local.
if isinstance(learner, Learner):
@@ -130,10 +131,17 @@ def __init__(

@OverrideToImplementCustomLogic
def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, List[EpisodeType]]:
# Map the batch to episodes.
episodes = self._map_to_episodes(
self._is_multi_agent, batch, schema=SCHEMA | self.config.input_read_schema
)

# If we read episodes directly, we only need to convert the batch to a list.
if self.input_read_episodes:
episodes = batch["item"].tolist()
# Otherwise, we map the batch to episodes.
else:
episodes = self._map_to_episodes(
self._is_multi_agent,
batch,
schema=SCHEMA | self.config.input_read_schema,
)["episodes"]
# TODO (simon): Make syncing work. Right now this becomes blocking or never
# receives weights. Learners appear to be not accessible via other actors.
# Increase the counter for updating the module.
Expand Down Expand Up @@ -165,7 +173,7 @@ def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, List[EpisodeType]]
batch = self._learner_connector(
rl_module=self._module,
data={},
episodes=episodes["episodes"],
episodes=episodes,
shared_data={},
)
# Convert to `MultiAgentBatch`.
@@ -176,7 +184,7 @@ def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, List[EpisodeType]]
},
# TODO (simon): This can be run once for the batch and the
# metrics, but we run it twice: here and later in the learner.
env_steps=sum(e.env_steps() for e in episodes["episodes"]),
env_steps=sum(e.env_steps() for e in episodes),
)
# Remove all data from modules that should not be trained. We do
# not want to pass around more data than necessary.
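Finally, a rough standalone illustration (with a dummy stand-in for `SingleAgentEpisode`, and assuming Ray Data's behavior of storing Python objects under an "item" column) of why `batch["item"].tolist()` in `OfflinePreLearner.__call__` recovers the episode objects:

```python
import ray


class DummyEpisode:
    """Stand-in for `SingleAgentEpisode`, only for this illustration."""

    def __init__(self, num_steps: int):
        self.num_steps = num_steps

    def env_steps(self) -> int:
        return self.num_steps


# Datasets built from Python objects keep them in a single "item" column.
ds = ray.data.from_items([DummyEpisode(5), DummyEpisode(7)])


def count_env_steps(batch):
    # Same conversion as in the PreLearner when `input_read_episodes=True`.
    episodes = batch["item"].tolist()
    return {"env_steps": [sum(e.env_steps() for e in episodes)]}


print(ds.map_batches(count_env_steps, batch_format="numpy").take_all())
```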