Commit 710f557

[RLlib] Cleanup, rename, clarify: Algorithm.workers/evaluation_workers, local_worker(), etc.. (#46726)

sven1977 authored Jul 22, 2024
1 parent 232c331 commit 710f557
Showing 56 changed files with 562 additions and 785 deletions.
2 changes: 1 addition & 1 deletion doc/source/rllib/doc_code/dreamerv3_inference.py
@@ -28,7 +28,7 @@
algo = config.build()

# Extract the actual RLModule from the local (Dreamer) EnvRunner.
rl_module = algo.workers.local_worker().module
rl_module = algo.env_runner.module
# Get initial states from RLModule (note that these are always B=1, so this matches
# our num_envs=1; if you are using a vector env >1, you would have to repeat the
# returned states `num_env` times to get the correct batch size):
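# NOTE: A hedged sketch (not part of this diff) of the repetition step described in
# the comment above. It assumes the initial states arrive as a flat dict of arrays
# with B=1; `num_envs` is a hypothetical vector-env size, not a name from this file.
import numpy as np

init_states = rl_module.get_initial_state()
batched_states = {
    key: np.repeat(np.asarray(value), num_envs, axis=0)
    for key, value in init_states.items()
}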
8 changes: 5 additions & 3 deletions doc/source/rllib/doc_code/getting_started.py
@@ -112,13 +112,15 @@
algo.get_policy().get_weights()

# Same as above
algo.workers.local_worker().policy_map["default_policy"].get_weights()
algo.env_runner.policy_map["default_policy"].get_weights()

# Get list of weights of each worker, including remote replicas
algo.workers.foreach_worker(lambda worker: worker.get_policy().get_weights())
algo.env_runner_group.foreach_worker(
lambda env_runner: env_runner.get_policy().get_weights()
)

# Same as above, but with index.
algo.workers.foreach_worker_with_id(
algo.env_runner_group.foreach_worker_with_id(
lambda _id, worker: worker.get_policy().get_weights()
)
# __rllib-get-state-end__
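# NOTE: A hedged follow-up sketch (not part of this diff): after modifying the local
# ("master") copy's weights, they can be pushed back out to all remote EnvRunners.
# `new_weights` is a hypothetical dict in the same format `get_weights()` returns.
algo.get_policy().set_weights(new_weights)
algo.env_runner_group.sync_weights()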
18 changes: 9 additions & 9 deletions doc/source/rllib/key-concepts.rst
@@ -262,15 +262,15 @@ An example implementation of VPG could look like the following:
def training_step(self) -> ResultDict:
# 1. Sampling.
train_batch = synchronous_parallel_sample(
worker_set=self.workers,
worker_set=self.env_runner_group,
max_env_steps=self.config["train_batch_size"]
)
# 2. Updating the Policy.
train_results = train_one_step(self, train_batch)
# 3. Synchronize worker weights.
self.workers.sync_weights()
self.env_runner_group.sync_weights()
# 4. Return results.
return train_results
@@ -290,11 +290,11 @@ In the first step, we collect trajectory data from the environment(s):
.. code-block:: python
train_batch = synchronous_parallel_sample(
worker_set=self.workers,
worker_set=self.env_runner_group,
max_env_steps=self.config["train_batch_size"]
)
Here, ``self.workers`` is a set of ``RolloutWorkers`` that are created in the ``Algorithm``'s ``setup()`` method
Here, ``self.env_runner_group`` is a set of ``EnvRunners`` that are created in the ``Algorithm``'s ``setup()`` method
(prior to calling ``training_step()``).
This :py:class:`~ray.rllib.env.env_runner_group.EnvRunnerGroup` is covered in greater depth on the :ref:`EnvRunnerGroup documentation page <workerset-reference-docs>`.
The utility function ``synchronous_parallel_sample`` can be used for parallel sampling in a blocking
@@ -312,19 +312,19 @@ The ``train_batch`` is then passed to another utility function: ``train_one_step``
Methods like ``train_one_step`` and ``multi_gpu_train_one_step`` are used for training our Policy.
Further documentation with examples can be found on the :ref:`train ops documentation page <train-ops-docs>`.
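As a hedged illustration (not a line from this diff), the call inside ``training_step()`` typically looks as follows; the import is from RLlib's ``ray.rllib.execution.train_ops`` module:

.. code-block:: python

    from ray.rllib.execution.train_ops import train_one_step

    # Update the local policy on the sampled batch; returns a results dict
    # keyed by policy ID.
    train_results = train_one_step(self, train_batch)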

The training updates on the policy are only applied to its version inside ``self.workers.local_worker``.
The training updates on the policy are only applied to its version inside ``self.env_runner``.
Note that each :py:class:`~ray.rllib.env.env_runner_group.EnvRunnerGroup` has n remote :py:class:`~ray.rllib.env.env_runner.EnvRunner` instances and exactly one "local worker" and that all EnvRunners (remote and local ones)
hold a copy of the policy.
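A minimal, hedged sketch (not part of this diff) of how these copies can be reached from inside ``training_step()``, using the attribute names this commit introduces:

.. code-block:: python

    # The local EnvRunner, whose policy copy receives the gradient updates:
    local_env_runner = self.env_runner_group.local_env_runner

    # Every EnvRunner (local and remote) holds its own policy copy; foreach_worker
    # applies a function to each one and returns one result per EnvRunner:
    policy_ids = self.env_runner_group.foreach_worker(
        lambda env_runner: list(env_runner.policy_map.keys())
    )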

Now that we updated the local policy (the copy in ``self.workers.local_worker``), we need to make sure
that the copies in all remote workers (``self.workers.remote_workers``) have their weights synchronized
Now that we updated the local policy (the copy in ``self.env_runner_group.local_env_runner``), we need to make sure
that the copies in all remote workers (``self.env_runner_group.remote_workers``) have their weights synchronized
(from the local one):

.. code-block:: python
self.workers.sync_weights()
self.env_runner_group.sync_weights()
By calling ``self.workers.sync_weights()``,
By calling ``self.env_runner_group.sync_weights()``,
weights are broadcasted from the local worker to the remote workers. See :ref:`rollout worker
reference docs <rolloutworker-reference-docs>` for further details.

2 changes: 1 addition & 1 deletion doc/source/rllib/package_ref/evaluation.rst
@@ -13,7 +13,7 @@ Data ingest via either environment rollouts or other data-generating methods
(e.g. reading from offline files) is done in RLlib by :py:class:`~ray.rllib.env.env_runner.EnvRunner` instances,
which sit inside a :py:class:`~ray.rllib.env.env_runner_group.EnvRunnerGroup`
(together with other parallel ``EnvRunners``) in the RLlib :py:class:`~ray.rllib.algorithms.algorithm.Algorithm`
(under the ``self.workers`` property):
(under the ``self.env_runner_group`` property):


.. https://docs.google.com/drawings/d/1OewMLAu6KZNon7zpDfZnTh9qiT6m-3M9wnkqWkQQMRc/edit
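As a hedged illustration (not part of this diff) of the renamed accessors, these objects might be reached on a built ``Algorithm`` roughly as follows; ``config`` stands for any ``AlgorithmConfig`` and is an assumption here:

.. code-block:: python

    algo = config.build()
    env_runner_group = algo.env_runner_group  # EnvRunnerGroup holding all EnvRunners
    local_env_runner = algo.env_runner        # shortcut to the group's local EnvRunner
    num_remote = env_runner_group.num_remote_workers()  # count of remote EnvRunner actors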
2 changes: 1 addition & 1 deletion doc/source/rllib/rllib-advanced-api.rst
@@ -60,7 +60,7 @@ overriding the :py:meth:`~ray.rllib.algorithms.callbacks.DefaultCallbacks.on_train_result`
task = 1
else:
task = 0
algorithm.workers.foreach_worker(
algorithm.env_runner_group.foreach_worker(
lambda ev: ev.foreach_env(
lambda env: env.set_task(task)))
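The fragment above is part of a callback; a hedged, self-contained sketch of what such a callback could look like with the renamed attribute follows (the class name, metric key, 200.0 threshold, and the env's ``set_task()`` method are illustrative assumptions):

.. code-block:: python

    from ray.rllib.algorithms.callbacks import DefaultCallbacks


    class CurriculumCallbacks(DefaultCallbacks):
        def on_train_result(self, *, algorithm, result, **kwargs):
            # Derive the next task from the latest training results
            # (metric key and threshold are assumptions).
            if result["env_runners"]["episode_return_mean"] > 200.0:
                task = 1
            else:
                task = 0
            # Push the new task to every sub-environment on every EnvRunner.
            algorithm.env_runner_group.foreach_worker(
                lambda env_runner: env_runner.foreach_env(
                    lambda env: env.set_task(task)
                )
            )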
10 changes: 5 additions & 5 deletions doc/source/rllib/rllib-concepts.rst
@@ -249,11 +249,11 @@ the same pre-loaded batch:
# Collect SampleBatches from sample workers until we have a full batch.
if self._by_agent_steps:
train_batch = synchronous_parallel_sample(
worker_set=self.workers, max_agent_steps=self.config["train_batch_size"]
worker_set=self.env_runner_group, max_agent_steps=self.config["train_batch_size"]
)
else:
train_batch = synchronous_parallel_sample(
worker_set=self.workers, max_env_steps=self.config["train_batch_size"]
worker_set=self.env_runner_group, max_env_steps=self.config["train_batch_size"]
)
train_batch = train_batch.as_multi_agent()
self._counters[NUM_AGENT_STEPS_SAMPLED] += train_batch.agent_steps()
@@ -273,9 +273,9 @@ the same pre-loaded batch:
# Update weights - after learning on the local worker - on all remote
# workers.
if self.workers.remote_workers():
if self.env_runner_group.remote_workers():
with self._timers[WORKER_UPDATE_TIMER]:
self.workers.sync_weights(global_vars=global_vars)
self.env_runner_group.sync_weights(global_vars=global_vars)
# For each policy: update KL scale and warn about possible issues
for policy_id, policy_info in train_results.items():
@@ -285,7 +285,7 @@ the same pre-loaded batch:
self.get_policy(policy_id).update_kl(kl_divergence)
# Update global vars on local worker as well.
self.workers.local_worker().set_global_vars(global_vars)
self.env_runner.set_global_vars(global_vars)
return train_results
6 changes: 3 additions & 3 deletions doc/source/rllib/rllib-training.rst
@@ -212,13 +212,13 @@ or get model weights.
In RLlib algorithm state is replicated across multiple *rollout workers* (Ray actors)
in the cluster.
However, you can easily get and update this state between calls to ``train()``
via ``Algorithm.workers.foreach_worker()``
or ``Algorithm.workers.foreach_worker_with_index()``.
via ``Algorithm.env_runner_group.foreach_worker()``
or ``Algorithm.env_runner_group.foreach_worker_with_index()``.
These functions take a lambda function that is applied with the worker as an argument.
These functions return values for each worker as a list.
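A hedged sketch (not part of this diff) of that pattern, reusing the ``foreach_worker_with_id`` spelling that appears elsewhere in this commit:

.. code-block:: python

    # One entry per EnvRunner in the returned list:
    weights_per_worker = algo.env_runner_group.foreach_worker(
        lambda env_runner: env_runner.get_policy().get_weights()
    )

    # Same, but the lambda also receives each worker's ID:
    weights_with_ids = algo.env_runner_group.foreach_worker_with_id(
        lambda worker_id, env_runner: (worker_id, env_runner.get_policy().get_weights())
    )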

You can also access just the "master" copy of the algorithm state through
``Algorithm.get_policy()`` or ``Algorithm.workers.local_worker()``,
``Algorithm.get_policy()`` or ``Algorithm.env_runner``,
but note that updates here may not be immediately reflected in
your rollout workers (if you have configured ``num_env_runners > 0``).
Here's a quick example of how to access state of a model:
8 changes: 0 additions & 8 deletions rllib/BUILD
@@ -1973,14 +1973,6 @@ py_test(
srcs = ["tests/test_lstm.py"]
)

py_test(
name = "tests/test_model_imports",
tags = ["team:rllib", "tests_dir", "model_imports"],
size = "medium",
data = glob(["tests/data/model_weights/**"]),
srcs = ["tests/test_model_imports.py"]
)

py_test(
name = "tests/test_nested_observation_spaces",
main = "tests/test_nested_observation_spaces.py",