Commit: [RLlib] Preparatory PR for multi-agent, multi-GPU learning agent (alpha-star style) #2. (ray-project#21649)
Showing 12 changed files with 596 additions and 170 deletions.
@@ -0,0 +1,146 @@
import collections
import platform
import random
from typing import Optional

from ray.rllib.execution.replay_ops import SimpleReplayBuffer
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch
from ray.rllib.utils.timer import TimerStat
from ray.rllib.utils.typing import PolicyID, SampleBatchType


class MixInMultiAgentReplayBuffer:
    """This buffer adds replayed samples to a stream of new experiences.

    - Any newly added batch (`add_batch()`) is immediately returned upon
      the next `replay` call (close to on-policy) as well as being moved
      into the buffer.
    - Additionally, a certain number of old samples is mixed into the
      returned sample according to a given "replay ratio".
    - If >1 calls to `add_batch()` are made without any `replay()` calls
      in between, all newly added batches are returned (plus some older
      samples according to the "replay ratio").

    Examples:
        # Replay ratio 0.66 (2/3 replayed, 1/3 new samples):
        >>> buffer = MixInMultiAgentReplayBuffer(capacity=100,
        ...                                      replay_ratio=0.66)
        >>> buffer.add_batch(<A>)
        >>> buffer.add_batch(<B>)
        >>> buffer.replay()
        ... [<A>, <B>, <B>]
        >>> buffer.add_batch(<C>)
        >>> buffer.replay()
        ... [<C>, <A>, <B>]
        >>> # or: [<C>, <A>, <A>] or [<C>, <B>, <B>], but always <C> as it
        >>> # is the newest sample
        >>> buffer.add_batch(<D>)
        >>> buffer.replay()
        ... [<D>, <A>, <C>]

        # Replay ratio 0.0 -> replay disabled:
        >>> buffer = MixInMultiAgentReplayBuffer(capacity=100,
        ...                                      replay_ratio=0.0)
        >>> buffer.add_batch(<A>)
        >>> buffer.replay()
        ... [<A>]
        >>> buffer.add_batch(<B>)
        >>> buffer.replay()
        ... [<B>]
    """

    def __init__(self, capacity: int, replay_ratio: float):
        """Initializes a MixInMultiAgentReplayBuffer instance.

        Args:
            capacity (int): Number of batches to store in total.
            replay_ratio (float): Ratio of replayed samples in the returned
                batches. E.g. a ratio of 0.0 means only return new samples
                (no replay), a ratio of 0.5 means always return the newest
                sample plus one old one (1:1), a ratio of 0.66 means always
                return the newest sample plus 2 old (replayed) ones (1:2),
                etc..
        """
        self.capacity = capacity
        self.replay_ratio = replay_ratio
        # For replay_ratio < 1.0: replay_proportion = old/new
        # = replay_ratio / (1.0 - replay_ratio).
        self.replay_proportion = None
        if self.replay_ratio != 1.0:
            self.replay_proportion = self.replay_ratio / (
                1.0 - self.replay_ratio)

        def new_buffer():
            return SimpleReplayBuffer(num_slots=capacity)

        self.replay_buffers = collections.defaultdict(new_buffer)

        # Metrics.
        self.add_batch_timer = TimerStat()
        self.replay_timer = TimerStat()
        self.update_priorities_timer = TimerStat()

        # Added timesteps over lifetime.
        self.num_added = 0

        # Last added batch(es).
        self.last_added_batches = collections.defaultdict(list)

    def add_batch(self, batch: SampleBatchType) -> None:
        """Adds a batch to the appropriate policy's replay buffer.

        Turns the batch into a MultiAgentBatch of the DEFAULT_POLICY_ID if
        it is not a MultiAgentBatch. Subsequently adds the individual policy
        batches to the storage.

        Args:
            batch: The batch to be added.
        """
        # Make a copy so the replay buffer doesn't pin plasma memory.
        batch = batch.copy()
        batch = batch.as_multi_agent()

        with self.add_batch_timer:
            for policy_id, sample_batch in batch.policy_batches.items():
                self.replay_buffers[policy_id].add_batch(sample_batch)
                self.last_added_batches[policy_id].append(sample_batch)
        self.num_added += batch.count

    def replay(self, policy_id: PolicyID = DEFAULT_POLICY_ID) -> \
            Optional[SampleBatchType]:
        """Returns a mix of new and replayed samples for the given policy.

        Args:
            policy_id: The ID of the policy whose buffer to sample from.

        Returns:
            A concatenated SampleBatch (newly added batches plus replayed
            ones according to the replay ratio), or None if nothing can be
            returned yet.
        """
        buffer = self.replay_buffers[policy_id]
        # Return None, if:
        # - Buffer is empty, or
        # - `replay_ratio` < 1.0 (new samples are required in the returned
        #   batch) and there are no new samples to mix with replayed ones.
        if len(buffer) == 0 or (len(self.last_added_batches[policy_id]) == 0
                                and self.replay_ratio < 1.0):
            return None

        # Mix buffer's last added batches with older replayed batches.
        with self.replay_timer:
            output_batches = self.last_added_batches[policy_id]
            self.last_added_batches[policy_id] = []

            # No replay desired -> Return here.
            if self.replay_ratio == 0.0:
                return SampleBatch.concat_samples(output_batches)
            # Only replay desired -> Return a (replayed) sample from the
            # buffer.
            elif self.replay_ratio == 1.0:
                return buffer.replay()

            # Replay ratio = old / [old + new]
            # Replay proportion: old / new
            num_new = len(output_batches)
            replay_proportion = self.replay_proportion
            while random.random() < num_new * replay_proportion:
                replay_proportion -= 1
                output_batches.append(buffer.replay())
            return SampleBatch.concat_samples(output_batches)

    def get_host(self) -> str:
        """Returns the computer's network name.

        Returns:
            The computer's network name or an empty string, if the network
            name could not be determined.
        """
        return platform.node()
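A minimal usage sketch (not part of the diff) of how `replay_ratio` shapes the output of `replay()`. It assumes the MixInMultiAgentReplayBuffer class above is importable; the batch contents and printed counts are illustrative only:

    import random

    from ray.rllib.policy.sample_batch import SampleBatch

    random.seed(0)  # make the stochastic old/new mixing reproducible

    # replay_ratio=0.5 -> on average, one replayed (old) batch is
    # returned per newly added batch.
    buffer = MixInMultiAgentReplayBuffer(capacity=100, replay_ratio=0.5)

    for i in range(4):
        buffer.add_batch(SampleBatch({"obs": [i]}))
        mixed = buffer.replay()  # newest batch + ~1 replayed batch
        print(mixed.count)  # typically 2: one new + one replayed timestep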
@@ -0,0 +1,152 @@
import logging
from typing import Any, Callable, DefaultDict, Dict, List, Optional, Set

import ray
from ray.actor import ActorHandle
from ray.rllib.utils.annotations import ExperimentalAPI

logger = logging.getLogger(__name__)


@ExperimentalAPI
def asynchronous_parallel_requests(
        remote_requests_in_flight: DefaultDict[ActorHandle, Set[
            ray.ObjectRef]],
        actors: List[ActorHandle],
        ray_wait_timeout_s: Optional[float] = None,
        max_remote_requests_in_flight_per_actor: int = 2,
        remote_fn: Optional[Callable[[ActorHandle, Any, Any], Any]] = None,
        remote_args: Optional[List[List[Any]]] = None,
        remote_kwargs: Optional[List[Dict[str, Any]]] = None,
) -> Dict[ActorHandle, Any]:
    """Runs parallel and asynchronous rollouts on all remote workers.

    May use a timeout (if provided) on `ray.wait()` and returns only those
    samples that could be gathered in the timeout window. Allows a maximum
    of `max_remote_requests_in_flight_per_actor` remote calls to be
    in-flight per remote actor.

    Alternatively to calling `actor.sample.remote()`, the user can provide
    a `remote_fn()`, which will be applied to the actor(s) instead.

    Args:
        remote_requests_in_flight: Dict mapping actor handles to a set of
            their currently-in-flight pending requests (those we expect to
            ray.get results for next). If you have an RLlib Trainer that
            calls this function, you can use its
            `self.remote_requests_in_flight` property here.
        actors: The list of ActorHandles to perform the remote requests on.
        ray_wait_timeout_s: Timeout (in sec) to be used for the underlying
            `ray.wait()` calls. If None (default), never time out (block
            until at least one actor returns something).
        max_remote_requests_in_flight_per_actor: Maximum number of remote
            requests sent to each actor. 2 (default) is probably
            sufficient to avoid idle times between two requests.
        remote_fn: If provided, use `actor.apply.remote(remote_fn)` instead
            of `actor.sample.remote()` to generate the requests.
        remote_args: If provided, use this list (per-actor) of lists (call
            args) as *args to be passed to the `remote_fn`.
            E.g.: actors=[A, B],
            remote_args=[[...] <- *args for A, [...] <- *args for B].
        remote_kwargs: If provided, use this list (per-actor) of dicts
            (kwargs) as **kwargs to be passed to the `remote_fn`.
            E.g.: actors=[A, B],
            remote_kwargs=[{...} <- **kwargs for A,
            {...} <- **kwargs for B].

    Returns:
        A dict mapping actor handles to the results received by sending
        requests to these actors. An empty dict, if no results are ready
        within the timeout window.

    Examples:
        >>> # 2 remote rollout workers (num_workers=2):
        >>> batches = asynchronous_parallel_requests(
        ...     trainer.remote_requests_in_flight,
        ...     actors=trainer.workers.remote_workers(),
        ...     ray_wait_timeout_s=0.1,
        ...     remote_fn=lambda w: time.sleep(1)  # sleep 1sec
        ... )
        >>> print(len(batches))
        ... 0
        >>> # Expect a timeout to have happened: the workers are still
        >>> # busy sleeping, so nothing was ready within 0.1s.
    """

    if remote_args is not None:
        assert len(remote_args) == len(actors)
    if remote_kwargs is not None:
        assert len(remote_kwargs) == len(actors)

    # For faster hash lookup.
    actor_set = set(actors)

    # Collect all currently pending remote requests into a single set of
    # object refs.
    pending_remotes = set()
    # Also build a map to get the associated actor for each remote request.
    remote_to_actor = {}
    for actor, set_ in remote_requests_in_flight.items():
        # Only consider those actors' pending requests that are in
        # the given `actors` list.
        if actor in actor_set:
            pending_remotes |= set_
            for r in set_:
                remote_to_actor[r] = actor

    # Add new requests, if possible (if
    # `max_remote_requests_in_flight_per_actor` setting allows it).
    for actor_idx, actor in enumerate(actors):
        # Still room for another request to this actor.
        if len(remote_requests_in_flight[actor]) < \
                max_remote_requests_in_flight_per_actor:
            if remote_fn is None:
                req = actor.sample.remote()
            else:
                args = remote_args[actor_idx] if remote_args else []
                kwargs = remote_kwargs[actor_idx] if remote_kwargs else {}
                req = actor.apply.remote(remote_fn, *args, **kwargs)
            # Add to our set to send to ray.wait().
            pending_remotes.add(req)
            # Keep our mappings properly updated.
            remote_requests_in_flight[actor].add(req)
            remote_to_actor[req] = actor

    # There must always be pending remote requests.
    assert len(pending_remotes) > 0
    pending_remote_list = list(pending_remotes)

    # No timeout: Block until at least one result is returned.
    if ray_wait_timeout_s is None:
        # First, try to do a `ray.wait` w/o timeout for efficiency.
        ready, _ = ray.wait(
            pending_remote_list, num_returns=len(pending_remotes), timeout=0)
        # Nothing returned and `timeout` is None -> Fall back to a
        # blocking wait to make sure we can return something.
        if not ready:
            ready, _ = ray.wait(pending_remote_list, num_returns=1)
    # Timeout: Do a `ray.wait()` call w/ timeout.
    else:
        ready, _ = ray.wait(
            pending_remote_list,
            num_returns=len(pending_remotes),
            timeout=ray_wait_timeout_s)

        # Return empty results if nothing is ready after the timeout.
        if not ready:
            return {}

    # Remove in-flight records for ready refs.
    for obj_ref in ready:
        remote_requests_in_flight[remote_to_actor[obj_ref]].remove(obj_ref)

    # Do one ray.get().
    results = ray.get(ready)
    assert len(ready) == len(results)

    # Return mapping from (ready) actors to their results.
    ret = {}
    for obj_ref, result in zip(ready, results):
        ret[remote_to_actor[obj_ref]] = result

    return ret
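A hedged usage sketch (not part of the diff) of a driver-side loop built on the function above: it keeps up to two sample() requests in flight per rollout worker and trains on whatever arrives within each 10ms polling window. `workers` and `train_on_batch` are hypothetical stand-ins for a WorkerSet's remote workers and a training step:

    import collections

    # Maps each actor handle to its set of in-flight object refs; the
    # function mutates this dict in place across calls.
    requests_in_flight = collections.defaultdict(set)

    for _ in range(1000):
        # Launch new sample() requests where capacity allows, then poll
        # for up to 10ms; returns {actor: sample_batch} for whatever
        # finished (possibly an empty dict).
        ready = asynchronous_parallel_requests(
            requests_in_flight,
            actors=workers,  # hypothetical: list of remote rollout workers
            ray_wait_timeout_s=0.01,
        )
        for actor, sample_batch in ready.items():
            train_on_batch(sample_batch)  # hypothetical training step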