
Commit

[RLlib] Preparatory PR for multi-agent, multi-GPU learning agent (alpha-star style) #2. (ray-project#21649)
sven1977 authored Jan 27, 2022
1 parent 8ebc50f commit ee41800
Showing 12 changed files with 596 additions and 170 deletions.
146 changes: 146 additions & 0 deletions rllib/execution/buffers/mixin_replay_buffer.py
@@ -0,0 +1,146 @@
import collections
import platform
import random
from typing import Optional

from ray.rllib.execution.replay_ops import SimpleReplayBuffer
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch
from ray.rllib.utils.timer import TimerStat
from ray.rllib.utils.typing import PolicyID, SampleBatchType


class MixInMultiAgentReplayBuffer:
"""This buffer adds replayed samples to a stream of new experiences.
- Any newly added batch (`add_batch()`) is immediately returned upon
the next `replay` call (close to on-policy) as well as being moved
into the buffer.
- Additionally, a certain number of old samples is mixed into the
returned sample according to a given "replay ratio".
- If >1 calls to `add_batch()` are made without any `replay()` calls
in between, all newly added batches are returned (plus some older samples
according to the "replay ratio").
Examples:
# replay ratio 0.66 (2/3 replayed, 1/3 new samples):
>>> buffer = MixInMultiAgentReplayBuffer(capacity=100,
... replay_ratio=0.66)
>>> buffer.add_batch(<A>)
>>> buffer.add_batch(<B>)
>>> buffer.replay()
... [<A>, <B>, <B>]
>>> buffer.add_batch(<C>)
>>> buffer.replay()
... [<C>, <A>, <B>]
>>> # or: [<C>, <A>, <A>] or [<C>, <B>, <B>], but always <C> as it
>>> # is the newest sample
>>> buffer.add_batch(<D>)
>>> buffer.replay()
... [<D>, <A>, <C>]
# replay proportion 0.0 -> replay disabled:
>>> buffer = MixInMultiAgentReplayBuffer(capacity=100, replay_ratio=0.0)
>>> buffer.add_batch(<A>)
>>> buffer.replay()
... [<A>]
>>> buffer.add_batch(<B>)
>>> buffer.replay()
... [<B>]
"""

def __init__(self, capacity: int, replay_ratio: float):
"""Initializes MixInReplay instance.
Args:
capacity (int): Number of batches to store in total.
replay_ratio (float): Ratio of replayed samples in the returned
batches. E.g. a ratio of 0.0 means only return new samples
(no replay), a ratio of 0.5 means always return the newest sample
plus one old one (1:1), and a ratio of 0.66 means always return
the newest sample plus two old (replayed) ones (1:2), etc.
"""
self.capacity = capacity
self.replay_ratio = replay_ratio
self.replay_proportion = None
if self.replay_ratio != 1.0:
self.replay_proportion = self.replay_ratio / (
1.0 - self.replay_ratio)

def new_buffer():
return SimpleReplayBuffer(num_slots=capacity)

self.replay_buffers = collections.defaultdict(new_buffer)

# Metrics.
self.add_batch_timer = TimerStat()
self.replay_timer = TimerStat()
self.update_priorities_timer = TimerStat()

# Added timesteps over lifetime.
self.num_added = 0

# Last added batch(es).
self.last_added_batches = collections.defaultdict(list)

def add_batch(self, batch: SampleBatchType) -> None:
"""Adds a batch to the appropriate policy's replay buffer.
Turns the batch into a MultiAgentBatch of the DEFAULT_POLICY_ID if
it is not a MultiAgentBatch. Subsequently adds the individual policy
batches to the storage.
Args:
batch: The batch to be added.
"""
# Make a copy so the replay buffer doesn't pin plasma memory.
batch = batch.copy()
batch = batch.as_multi_agent()

with self.add_batch_timer:
for policy_id, sample_batch in batch.policy_batches.items():
self.replay_buffers[policy_id].add_batch(sample_batch)
self.last_added_batches[policy_id].append(sample_batch)
self.num_added += batch.count

def replay(self, policy_id: PolicyID = DEFAULT_POLICY_ID) -> \
Optional[SampleBatchType]:
buffer = self.replay_buffers[policy_id]
# Return None, if:
# - Buffer empty or
# - `replay_ratio` < 1.0 (new samples required in returned batch)
# and no new samples to mix with replayed ones.
if len(buffer) == 0 or (len(self.last_added_batches[policy_id]) == 0
and self.replay_ratio < 1.0):
return None

# Mix buffer's last added batches with older replayed batches.
with self.replay_timer:
output_batches = self.last_added_batches[policy_id]
self.last_added_batches[policy_id] = []

# No replay desired -> Return here.
if self.replay_ratio == 0.0:
return SampleBatch.concat_samples(output_batches)
# Only replay desired -> Return a (replayed) sample from the
# buffer.
elif self.replay_ratio == 1.0:
return buffer.replay()

# Replay ratio = old / [old + new]
# Replay proportion: old / new
num_new = len(output_batches)
replay_proportion = self.replay_proportion
while random.random() < num_new * replay_proportion:
replay_proportion -= 1
output_batches.append(buffer.replay())
return SampleBatch.concat_samples(output_batches)

def get_host(self) -> str:
"""Returns the computer's network name.
Returns:
The computer's network name, or an empty string if the network
name could not be determined.
"""
return platform.node()
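
Below is a minimal usage sketch (not part of the diff) of the new MixInMultiAgentReplayBuffer, assuming the RLlib version in this PR; the batch contents are dummy placeholder data.

# Usage sketch only; "obs"/"rewards" values below are dummy data.
from ray.rllib.execution.buffers.mixin_replay_buffer import \
    MixInMultiAgentReplayBuffer
from ray.rllib.policy.sample_batch import SampleBatch

# replay_ratio=0.66: returned samples should mostly be old (replayed) data.
buffer = MixInMultiAgentReplayBuffer(capacity=100, replay_ratio=0.66)

# Newly added batches are stored and also remembered as "last added".
buffer.add_batch(SampleBatch({"obs": [1], "rewards": [0.1]}))
buffer.add_batch(SampleBatch({"obs": [2], "rewards": [0.2]}))

# Returns the newly added batches plus some replayed ones, according to
# the replay ratio (or None if the buffer were empty).
mixed = buffer.replay()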
11 changes: 8 additions & 3 deletions rllib/execution/buffers/multi_agent_replay_buffer.py
@@ -1,6 +1,6 @@
import collections
import platform
from typing import Any, Dict
from typing import Any, Dict, Optional

import numpy as np
import ray
@@ -13,7 +13,7 @@
from ray.rllib.utils import deprecation_warning
from ray.rllib.utils.deprecation import DEPRECATED_VALUE
from ray.rllib.utils.timer import TimerStat
from ray.rllib.utils.typing import SampleBatchType
from ray.rllib.utils.typing import PolicyID, SampleBatchType
from ray.util.iter import ParallelIteratorWorker


@@ -195,7 +195,7 @@ def add_batch(self, batch: SampleBatchType) -> None:
time_slice, weight=weight)
self.num_added += batch.count

def replay(self) -> SampleBatchType:
def replay(self, policy_id: Optional[PolicyID] = None) -> SampleBatchType:
"""If this buffer was given a fake batch, return it, otherwise return
a MultiAgentBatch with samples.
"""
@@ -211,8 +211,13 @@ def replay(self) -> SampleBatchType:
# Lockstep mode: Sample from all policies at the same time an
equal number of steps.
if self.replay_mode == "lockstep":
assert policy_id is None, \
"`policy_id` specifier not allowed in `locksetp` mode!"
return self.replay_buffers[_ALL_POLICIES].sample(
self.replay_batch_size, beta=self.prioritized_replay_beta)
elif policy_id is not None:
return self.replay_buffers[policy_id].sample(
self.replay_batch_size, beta=self.prioritized_replay_beta)
else:
samples = {}
for policy_id, replay_buffer in self.replay_buffers.items():
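
A short sketch (not part of the diff) of how the new optional `policy_id` argument changes `replay()` calls; it assumes `ma_buffer` is a MultiAgentReplayBuffer that was already constructed and filled elsewhere with replay_mode="independent", and "pol1" is a placeholder policy ID.

# Sketch only: `ma_buffer` is assumed to be an already filled
# MultiAgentReplayBuffer in "independent" (non-lockstep) replay mode;
# "pol1" is a placeholder policy ID.

# Unchanged behavior: sample from all policies' buffers at once.
all_policies_sample = ma_buffer.replay()

# New: sample only from one policy's underlying buffer.
pol1_sample = ma_buffer.replay(policy_id="pol1")

# In "lockstep" mode, passing a `policy_id` triggers an AssertionError.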
22 changes: 16 additions & 6 deletions rllib/execution/buffers/replay_buffer.py
@@ -132,19 +132,25 @@ def add(self, item: SampleBatchType, weight: float) -> None:

@DeveloperAPI
def sample(self, num_items: int, beta: float = 0.0) -> SampleBatchType:
"""Sample a batch of experiences.
"""Sample a batch of size `num_items` from this buffer.
If less than `num_items` records are in this buffer, some samples in
the results may be repeated to fulfil the batch size (`num_items`)
request.
Args:
num_items: Number of items to sample from this buffer.
beta: This is ignored (only used by prioritized replay buffers).
beta: The prioritized replay beta value. Only relevant if this
ReplayBuffer is a PrioritizedReplayBuffer.
Returns:
Concatenated batch of items, or None if the buffer is empty.
"""
idxes = [
random.randint(0,
len(self._storage) - 1) for _ in range(num_items)
]
# If we don't have any samples yet in this buffer, return None.
if len(self) == 0:
return None

idxes = [random.randint(0, len(self) - 1) for _ in range(num_items)]
sample = self._encode_sample(idxes)
# Update our timesteps counters.
self._num_timesteps_sampled += len(sample)
@@ -282,6 +288,10 @@ def sample(self, num_items: int, beta: float) -> SampleBatchType:
"batch_indexes" fields denoting IS of each sampled
transition and original idxes in buffer of sampled experiences.
"""
# If we don't have any samples yet in this buffer, return None.
if len(self) == 0:
return None

assert beta >= 0.0

idxes = self._sample_proportional(num_items)
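
A small sketch (not part of the diff) of the new empty-buffer behavior, assuming this file's ReplayBuffer(capacity=...) constructor; the SampleBatch contents are dummy data.

from ray.rllib.execution.buffers.replay_buffer import ReplayBuffer
from ray.rllib.policy.sample_batch import SampleBatch

buffer = ReplayBuffer(capacity=4)

# Empty buffer: `sample()` now returns None instead of failing on an
# empty storage.
assert buffer.sample(2) is None

# With fewer stored items than `num_items`, samples may be repeated.
buffer.add(SampleBatch({"obs": [1]}), weight=1.0)
batch = buffer.sample(2)  # The single stored item, sampled twice.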
152 changes: 152 additions & 0 deletions rllib/execution/parallel_requests.py
@@ -0,0 +1,152 @@
import logging
from typing import Any, Callable, DefaultDict, Dict, List, Optional, Set

import ray
from ray.actor import ActorHandle
from ray.rllib.utils.annotations import ExperimentalAPI

logger = logging.getLogger(__name__)


@ExperimentalAPI
def asynchronous_parallel_requests(
remote_requests_in_flight: DefaultDict[ActorHandle, Set[
ray.ObjectRef]],
actors: List[ActorHandle],
ray_wait_timeout_s: Optional[float] = None,
max_remote_requests_in_flight_per_actor: int = 2,
remote_fn: Optional[Callable[[ActorHandle, Any, Any], Any]] = None,
remote_args: Optional[List[List[Any]]] = None,
remote_kwargs: Optional[List[Dict[str, Any]]] = None,
) -> Dict[ActorHandle, Any]:
"""Runs parallel and asynchronous rollouts on all remote workers.
May use a timeout (if provided) on `ray.wait()` and returns only those
samples that could be gathered in the timeout window. Allows a maximum
of `max_remote_requests_in_flight_per_actor` remote calls to be in-flight
per remote actor.
Instead of calling `actor.sample.remote()`, the user can provide a
`remote_fn()`, which will be applied to the actor(s).
Args:
remote_requests_in_flight: Dict mapping actor handles to a set of
their currently-in-flight pending requests (those we expect to
ray.get results for next). If you have an RLlib Trainer that calls
this function, you can use its `self.remote_requests_in_flight`
property here.
actors: The List of ActorHandles to perform the remote requests on.
ray_wait_timeout_s: Timeout (in sec) to be used for the underlying
`ray.wait()` calls. If None (default), never time out (block
until at least one actor returns something).
max_remote_requests_in_flight_per_actor: Maximum number of remote
requests sent to each actor. 2 (default) is probably
sufficient to avoid idle times between two requests.
remote_fn: If provided, use `actor.apply.remote(remote_fn)` instead of
`actor.sample.remote()` to generate the requests.
remote_args: If provided, use this list (per-actor) of lists (call
args) as *args to be passed to the `remote_fn`.
E.g.: actors=[A, B],
remote_args=[[...] <- *args for A, [...] <- *args for B].
remote_kwargs: If provided, use this list (per-actor) of dicts
(kwargs) as **kwargs to be passed to the `remote_fn`.
E.g.: actors=[A, B],
remote_kwargs=[{...} <- **kwargs for A, {...} <- **kwargs for B].
Returns:
A dict mapping actor handles to the results received by sending requests
to these actors.
An empty dict if no results are ready before the timeout expires.
Examples:
>>> # 2 remote rollout workers (num_workers=2):
>>> results = asynchronous_parallel_requests(
... trainer.remote_requests_in_flight,
... actors=trainer.workers.remote_workers(),
... ray_wait_timeout_s=0.1,
... remote_fn=lambda w: time.sleep(1)  # sleep 1sec
... )
>>> # Expect a timeout to have happened: neither worker is done after
>>> # 0.1s, so the returned dict is empty.
>>> print(len(results))
... 0
"""

if remote_args is not None:
assert len(remote_args) == len(actors)
if remote_kwargs is not None:
assert len(remote_kwargs) == len(actors)

# For faster hash lookup.
actor_set = set(actors)

# Collect all currently pending remote requests into a single set of
# object refs.
pending_remotes = set()
# Also build a map to get the associated actor for each remote request.
remote_to_actor = {}
for actor, set_ in remote_requests_in_flight.items():
# Only consider those actors' pending requests that are in
# the given `actors` list.
if actor in actor_set:
pending_remotes |= set_
for r in set_:
remote_to_actor[r] = actor

# Add new requests, if possible (if
# `max_remote_requests_in_flight_per_actor` setting allows it).
for actor_idx, actor in enumerate(actors):
# Still room for another request to this actor.
if len(remote_requests_in_flight[actor]) < \
max_remote_requests_in_flight_per_actor:
if remote_fn is None:
req = actor.sample.remote()
else:
args = remote_args[actor_idx] if remote_args else []
kwargs = remote_kwargs[actor_idx] if remote_kwargs else {}
req = actor.apply.remote(remote_fn, *args, **kwargs)
# Add to our set to send to ray.wait().
pending_remotes.add(req)
# Keep our mappings properly updated.
remote_requests_in_flight[actor].add(req)
remote_to_actor[req] = actor

# There must always be pending remote requests.
assert len(pending_remotes) > 0
pending_remote_list = list(pending_remotes)

# No timeout: Block until at least one result is returned.
if ray_wait_timeout_s is None:
# First try to do a `ray.wait` w/o timeout for efficiency.
ready, _ = ray.wait(
pending_remote_list, num_returns=len(pending_remotes), timeout=0)
# Nothing returned and `timeout` is None -> Fall back to a
# blocking wait to make sure we can return something.
if not ready:
ready, _ = ray.wait(pending_remote_list, num_returns=1)
# Timeout: Do a `ray.wait()` call w/ timeout.
else:
ready, _ = ray.wait(
pending_remote_list,
num_returns=len(pending_remotes),
timeout=ray_wait_timeout_s)

# Return empty results if nothing ready after the timeout.
if not ready:
return {}

# Remove in-flight records for ready refs.
for obj_ref in ready:
remote_requests_in_flight[remote_to_actor[obj_ref]].remove(obj_ref)

# Do one ray.get().
results = ray.get(ready)
assert len(ready) == len(results)

# Return mapping from (ready) actors to their results.
ret = {}
for obj_ref, result in zip(ready, results):
ret[remote_to_actor[obj_ref]] = result

return ret
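
A hedged usage sketch (not part of the diff) of asynchronous_parallel_requests outside of a Trainer; the Worker actor below is a placeholder that only provides the apply() method the utility calls when a remote_fn is given.

import collections
import ray
from ray.rllib.execution.parallel_requests import \
    asynchronous_parallel_requests

ray.init()

# Placeholder actor: exposes `apply()`, mirroring what the utility expects
# when a `remote_fn` is passed in.
@ray.remote
class Worker:
    def apply(self, fn, *args, **kwargs):
        return fn(self, *args, **kwargs)

actors = [Worker.remote(), Worker.remote()]
in_flight = collections.defaultdict(set)

# Send one request per actor; with no timeout, block until at least one
# result is ready.
results = asynchronous_parallel_requests(
    remote_requests_in_flight=in_flight,
    actors=actors,
    remote_fn=lambda w: "done",
)
# `results` maps each finished actor handle to its return value ("done").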
