[RLlib] ReplayBuffer API Simple Q #22842

Merged
Changes from all commits
96 commits
3d6befc
first draft of classes
ArturNiederfahrenhorst Feb 3, 2022
28f23d3
formatting
ArturNiederfahrenhorst Feb 4, 2022
16d8d1e
added config
ArturNiederfahrenhorst Feb 4, 2022
4988b2b
typo
ArturNiederfahrenhorst Feb 4, 2022
9c45591
Reservoir buffer sketch and new typehints for sample()
ArturNiederfahrenhorst Feb 5, 2022
15b2e04
wip, https://github.com/ray-project/ray/pull/22114#discussion_r79971…
ArturNiederfahrenhorst Feb 5, 2022
2c6daba
wip https://github.com/ray-project/ray/pull/22114#discussion_r799724172
ArturNiederfahrenhorst Feb 5, 2022
2d10d74
added missing docstrings
ArturNiederfahrenhorst Feb 5, 2022
83a2dcb
Partial MixInReplayBuffer rewrite with added get_state and set_state …
ArturNiederfahrenhorst Feb 5, 2022
49e75da
sven's nits
ArturNiederfahrenhorst Feb 9, 2022
9d17c4d
wip
ArturNiederfahrenhorst Feb 9, 2022
96a4250
Merge branch 'master' into ReplayBufferAPI_tests
ArturNiederfahrenhorst Feb 9, 2022
bfbc354
jungs TODO from initial ReplayBuffer PR
ArturNiederfahrenhorst Feb 10, 2022
ccacadc
first bunch of tests
ArturNiederfahrenhorst Feb 11, 2022
6afc21c
features and fixes that came with first couple of tests
ArturNiederfahrenhorst Feb 11, 2022
4e4dbe5
replay buffer and tests done
ArturNiederfahrenhorst Feb 15, 2022
95e0ee3
prioritized replay buffer and tests done
ArturNiederfahrenhorst Feb 15, 2022
f47a0a1
merge from master
ArturNiederfahrenhorst Feb 15, 2022
53f9dd8
wip
ArturNiederfahrenhorst Feb 15, 2022
5bd50ad
Apply suggestions from code review
sven1977 Feb 16, 2022
0b64d62
MultiAgentReplayBuffer and tests
ArturNiederfahrenhorst Feb 18, 2022
ee37a85
Merge remote-tracking branch 'origin/ReplayBufferAPI_tests' into Repl…
ArturNiederfahrenhorst Feb 18, 2022
c6a73e1
MultiAgentReplayBuffer better tests and warning
ArturNiederfahrenhorst Feb 19, 2022
13032ac
Added MultiAgentPrioritizedReplayBuffer and tests
ArturNiederfahrenhorst Feb 19, 2022
3da08fc
minors
ArturNiederfahrenhorst Feb 21, 2022
bf4a665
multi agent prioritized comments, fixes
ArturNiederfahrenhorst Feb 21, 2022
a7b7c3e
multi agent comments, fixes
ArturNiederfahrenhorst Feb 21, 2022
90f3eca
MultiAgentMixInReplayBuffer and tests
ArturNiederfahrenhorst Feb 21, 2022
888dca7
Reservoir Buffer and tests
ArturNiederfahrenhorst Feb 21, 2022
0fd7a63
wip
ArturNiederfahrenhorst Feb 22, 2022
db98ef3
wup
ArturNiederfahrenhorst Feb 22, 2022
2ac5916
Merge remote-tracking branch 'upstream/master' into ReplayBufferAPI_t…
ArturNiederfahrenhorst Feb 24, 2022
23f7122
fix
ArturNiederfahrenhorst Feb 24, 2022
a870ad0
Adds tests for timeslices and changes SampleBatch.rows()
ArturNiederfahrenhorst Feb 26, 2022
3537e2b
finalizes check_sample_batches and test_multi_agent_batch
ArturNiederfahrenhorst Feb 26, 2022
98abf64
finalizes PR
ArturNiederfahrenhorst Feb 26, 2022
21c4b4b
test_rows
ArturNiederfahrenhorst Feb 27, 2022
7a8d0f3
format.sh
ArturNiederfahrenhorst Mar 4, 2022
85aaaad
Adds tests to CI, comments
ArturNiederfahrenhorst Mar 4, 2022
aeae356
Merge remote-tracking branch 'upstream/master' into fix-multiagentbat…
ArturNiederfahrenhorst Mar 4, 2022
1d18245
small fix in mixin buffer, format.sh
ArturNiederfahrenhorst Mar 5, 2022
e57ce01
Merge remote-tracking branch 'upstream/master' into ReplayBufferAPI_t…
ArturNiederfahrenhorst Mar 5, 2022
40dfcac
Merge branch 'fix-multiagentbatch-timeslices' into ReplayBufferAPI_tests
ArturNiederfahrenhorst Mar 5, 2022
e7b0ace
make mixin sequence sampling test less flaky
ArturNiederfahrenhorst Mar 5, 2022
56276c5
include tests in rllib BUILD file before moving to critical path
ArturNiederfahrenhorst Mar 6, 2022
1006364
Merge branch 'ReplayBuffersAPI_config' into ReplayBufferAPI_Simple_Q
ArturNiederfahrenhorst Mar 6, 2022
7118f8e
initial
ArturNiederfahrenhorst Mar 6, 2022
9d84ec0
format
ArturNiederfahrenhorst Mar 6, 2022
3b6dd74
format
ArturNiederfahrenhorst Mar 6, 2022
7f7b602
format
ArturNiederfahrenhorst Mar 6, 2022
cf9fd38
fix conflicting test names
ArturNiederfahrenhorst Mar 7, 2022
67265d4
compatibility with APPO and CQL, committing to see if CI is satisfied
ArturNiederfahrenhorst Mar 7, 2022
da50351
wip
ArturNiederfahrenhorst Mar 7, 2022
2c3a0d1
fix BUILD file
ArturNiederfahrenhorst Mar 7, 2022
7c837d9
Merge branch 'ReplayBufferAPI_tests' into ReplayBufferAPI_Simple_Q
ArturNiederfahrenhorst Mar 7, 2022
3cf1405
wip
ArturNiederfahrenhorst Mar 7, 2022
ed8e5d1
format
ArturNiederfahrenhorst Mar 8, 2022
cd92285
Merge branch 'ReplayBufferAPI_tests' into ReplayBufferAPI_Simple_Q
ArturNiederfahrenhorst Mar 8, 2022
10772f4
wip
ArturNiederfahrenhorst Mar 8, 2022
d93c217
format
ArturNiederfahrenhorst Mar 8, 2022
01d356d
Merge branch 'ReplayBufferAPI_tests' into ReplayBufferAPI_Simple_Q
ArturNiederfahrenhorst Mar 8, 2022
ae0d2ac
Merge remote-tracking branch 'upstream' into ReplayBufferAPI_tests
ArturNiederfahrenhorst Mar 8, 2022
aa3336d
Merge branch 'ReplayBufferAPI_tests' into ReplayBufferAPI_Simple_Q
ArturNiederfahrenhorst Mar 8, 2022
d69ca75
adds train iter fn to simple q, removes possible legacy parameters, r…
ArturNiederfahrenhorst Mar 9, 2022
b130323
format
ArturNiederfahrenhorst Mar 11, 2022
2d8cf0c
Merge remote-tracking branch 'upstream/master' into ReplayBufferAPI_S…
ArturNiederfahrenhorst Mar 11, 2022
b6f1620
wip
ArturNiederfahrenhorst Mar 11, 2022
7645541
wip
ArturNiederfahrenhorst Mar 11, 2022
c8d85e4
wip
ArturNiederfahrenhorst Mar 12, 2022
ef042a6
wip
ArturNiederfahrenhorst Mar 12, 2022
ed606fd
wip
ArturNiederfahrenhorst Mar 13, 2022
193f94f
Merge branch 'master' into ReplayBufferAPI_Simple_Q
ArturNiederfahrenhorst Mar 13, 2022
d3060f9
wip
ArturNiederfahrenhorst Mar 13, 2022
8a678a9
fix faulty apex config
ArturNiederfahrenhorst Mar 14, 2022
84e1a84
fix apex ddpg
ArturNiederfahrenhorst Mar 14, 2022
b1fd2db
undo ddpg changes
ArturNiederfahrenhorst Mar 14, 2022
57f1903
wip
ArturNiederfahrenhorst Mar 14, 2022
7a618b8
fix faulty SAC configuration in test
ArturNiederfahrenhorst Mar 16, 2022
4fc1cdf
fix replay burn in for r2d2
ArturNiederfahrenhorst Mar 16, 2022
5bc7fc2
fixes burn_in and providing old MultiAgentReplayBuffer as a class
ArturNiederfahrenhorst Mar 17, 2022
b67c7ce
wip
ArturNiederfahrenhorst Mar 17, 2022
ab9d548
Merge branch 'master' into ReplayBufferAPI_Simple_Q
ArturNiederfahrenhorst Mar 17, 2022
f9c520c
use SYNCH_WORKER_WEIGHTS_TIMER
ArturNiederfahrenhorst Mar 17, 2022
d94c7af
wip
ArturNiederfahrenhorst Mar 17, 2022
0732d8d
simple Q is learning
ArturNiederfahrenhorst Mar 18, 2022
1d7ed47
Comments and docstrings
ArturNiederfahrenhorst Mar 18, 2022
32e8558
Merge branch 'master' into ReplayBufferAPI_Simple_Q
ArturNiederfahrenhorst Mar 18, 2022
bf5cbf1
Sven's comments
ArturNiederfahrenhorst Mar 22, 2022
a8f3b84
Merge branch 'master' into ReplayBufferAPI_Simple_Q
ArturNiederfahrenhorst Mar 22, 2022
a2d5851
jun's feedback
ArturNiederfahrenhorst Mar 23, 2022
9032c5d
Sven's nits
ArturNiederfahrenhorst Mar 23, 2022
2a832ae
Merge branch 'master' into ReplayBufferAPI_Simple_Q
ArturNiederfahrenhorst Mar 27, 2022
0277416
svents nits
ArturNiederfahrenhorst Mar 28, 2022
bd51df8
Sven's comments
ArturNiederfahrenhorst Mar 28, 2022
e4fab6c
Merge branch 'master' of https://github.com/ray-project/ray into Repl…
sven1977 Mar 29, 2022
b8781a5
Merge remote-tracking branch 'artur/ReplayBufferAPI_Simple_Q' into Re…
sven1977 Mar 29, 2022
4 changes: 0 additions & 4 deletions release/rllib_tests/learning_tests/hard_learning_tests.yaml
@@ -80,8 +80,6 @@ apex-breakoutnoframeskip-v4:
epsilon_timesteps: 200000
final_epsilon: 0.01
prioritized_replay_alpha: 0.5
final_prioritized_replay_beta: 1.0
prioritized_replay_beta_annealing_timesteps: 2000000
num_gpus: 1
num_workers: 8
num_envs_per_worker: 8
@@ -327,8 +325,6 @@ dqn-breakoutnoframeskip-v4:
epsilon_timesteps: 200000
final_epsilon: 0.01
prioritized_replay_alpha: 0.5
final_prioritized_replay_beta: 1.0
prioritized_replay_beta_annealing_timesteps: 2000000
num_gpus: 0.5
timesteps_per_iteration: 10000

2 changes: 0 additions & 2 deletions release/rllib_tests/performance_tests/performance_tests.yaml
@@ -53,8 +53,6 @@ apex-breakoutnoframeskip-v4:
epsilon_timesteps: 200000
final_epsilon: 0.01
prioritized_replay_alpha: 0.5
final_prioritized_replay_beta: 1.0
prioritized_replay_beta_annealing_timesteps: 2000000
num_gpus: 1
num_workers: 8
num_envs_per_worker: 8
2 changes: 1 addition & 1 deletion rllib/agents/ddpg/apex.py
@@ -19,7 +19,7 @@
"num_workers": 32,
"buffer_size": 2000000,
# TODO(jungong) : update once Apex supports replay_buffer_config.
"replay_buffer_config": None,
"no_local_replay_buffer": True,
# Whether all shards of the replay buffer must be co-located
# with the learner process (running the execution plan).
# This is preferred b/c the learner process should have quick
4 changes: 0 additions & 4 deletions rllib/agents/ddpg/ddpg.py
@@ -111,10 +111,6 @@
"prioritized_replay_alpha": 0.6,
# Beta parameter for sampling from prioritized replay buffer.
"prioritized_replay_beta": 0.4,
# Time steps over which the beta parameter is annealed.
"prioritized_replay_beta_annealing_timesteps": 20000,
# Final value of beta
"final_prioritized_replay_beta": 0.4,
# Epsilon to add to the TD errors when updating priorities.
"prioritized_replay_eps": 1e-6,
# Whether to LZ4 compress observations
8 changes: 4 additions & 4 deletions rllib/agents/dqn/apex.py
@@ -66,7 +66,7 @@
"buffer_size": 2000000,
# TODO(jungong) : add proper replay_buffer_config after
# DistributedReplayBuffer type is supported.
"replay_buffer_config": None,
"no_local_replay_buffer": True,
# Whether all shards of the replay buffer must be co-located
# with the learner process (running the execution plan).
# This is preferred b/c the learner process should have quick
@@ -157,9 +157,9 @@ def execution_plan(
config["learning_starts"],
config["buffer_size"],
config["train_batch_size"],
config["prioritized_replay_alpha"],
config["prioritized_replay_beta"],
config["prioritized_replay_eps"],
config["replay_buffer_config"]["prioritized_replay_alpha"],
config["replay_buffer_config"]["prioritized_replay_beta"],
config["replay_buffer_config"]["prioritized_replay_eps"],
config["multiagent"]["replay_mode"],
config.get("replay_sequence_length", 1),
]
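The hunk above changes where Ape-X's execution plan looks up its prioritized-replay parameters: they now live under the nested replay_buffer_config dict instead of the top level of the trainer config. A minimal sketch of the two layouts, using plain Python dicts with example values borrowed from the DQN defaults elsewhere in this diff (not a full Ape-X config):

```python
# Legacy layout: the execution plan read flat, top-level keys.
legacy_config = {
    "prioritized_replay_alpha": 0.6,
    "prioritized_replay_beta": 0.4,
    "prioritized_replay_eps": 1e-6,
}
alpha_old = legacy_config["prioritized_replay_alpha"]

# Layout assumed by the updated hunk: the same settings are nested
# under "replay_buffer_config" and indexed through it.
new_config = {
    "replay_buffer_config": {
        "prioritized_replay_alpha": 0.6,
        "prioritized_replay_beta": 0.4,
        "prioritized_replay_eps": 1e-6,
    },
}
alpha_new = new_config["replay_buffer_config"]["prioritized_replay_alpha"]

assert alpha_old == alpha_new == 0.6
```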
49 changes: 37 additions & 12 deletions rllib/agents/dqn/dqn.py
@@ -35,6 +35,7 @@
from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
from ray.rllib.utils.typing import TrainerConfigDict
from ray.util.iter import LocalIterator
from ray.rllib.utils.deprecation import DEPRECATED_VALUE

logger = logging.getLogger(__name__)

@@ -64,19 +65,37 @@
# N-step Q learning
"n_step": 1,

# === Prioritized replay buffer ===
# If True prioritized replay buffer will be used.
# === Replay buffer ===
# Size of the replay buffer. Note that if async_updates is set, then
# each worker will have a replay buffer of this size.
"buffer_size": DEPRECATED_VALUE,
# Prioritized replay is here since this algo uses the old replay
# buffer api
"prioritized_replay": True,
# Alpha parameter for prioritized replay buffer.
"prioritized_replay_alpha": 0.6,
# Beta parameter for sampling from prioritized replay buffer.
"prioritized_replay_beta": 0.4,
# Final value of beta (by default, we use constant beta=0.4).
"final_prioritized_replay_beta": 0.4,
# Time steps over which the beta parameter is annealed.
"prioritized_replay_beta_annealing_timesteps": 20000,
# Epsilon to add to the TD errors when updating priorities.
"prioritized_replay_eps": 1e-6,
"replay_buffer_config": {
# For now we don't use the new ReplayBuffer API here
"_enable_replay_buffer_api": False,
"type": "MultiAgentReplayBuffer",
"capacity": 50000,
"replay_batch_size": 32,
"prioritized_replay_alpha": 0.6,
# Beta parameter for sampling from prioritized replay buffer.
"prioritized_replay_beta": 0.4,
# Epsilon to add to the TD errors when updating priorities.
"prioritized_replay_eps": 1e-6,
},
# Set this to True, if you want the contents of your buffer(s) to be
# stored in any saved checkpoints as well.
# Warnings will be created if:
# - This is True AND restoring from a checkpoint that contains no buffer
# data.
# - This is False AND restoring from a checkpoint that does contain
# buffer data.
"store_buffer_in_checkpoints": False,
# The number of contiguous environment steps to replay at once. This may
# be set to greater than 1 to support recurrent models.
"replay_sequence_length": 1,
Review thread on replay_sequence_length:

Member: How does this work? I thought data in the replay buffer has already been post-processed, so these samples should all have the necessary state inputs for recurrent models.

Contributor Author: I think state inputs don't live in SampleBatches when they are stored in replay buffers. Recurrent state is passed through the forward() method of the ModelV2 API and is also initialized by the ModelV2 object via get_initial_state(). This should be taken into consideration in the connector design, right? @gjoliver

Member: OK, I need to double-check the code. It seems like the API for adding a SampleBatch assumes that the batch contains a full episode, slices it up according to replay_sequence_length, and stores multiple smaller batches as a result. Am I reading it right?

Contributor Author: Yes!

Member: OK, I read through everything; our codebase is really a mess. I believe SampleBatch does carry all the state_in/state_out columns; if you look at timeslice_along_seq_lens_with_overlap(), it handles the recurrent states correctly. All those complicated state-building logics in Sampler and SimpleListCollector are actually just for rollout, so I feel like we could clean up tons of CPU-heavy stuff that doesn't do anything today. By the way, if the ReplayBuffer is handling the batching of RNN states, how does RNN work for agents like PG that don't use a ReplayBuffer?

Member: Tested it out. It simply takes the raw batch with all the state_in and state_out columns, so it still runs fine. 👌

# Callback to run before learning on a multi-agent batch of
# experiences.
@@ -102,6 +121,12 @@
# === Parallelism ===
# Whether to compute priorities on workers.
"worker_side_prioritization": False,

# Experimental flag.
# If True, the execution plan API will not be used. Instead,
# a Trainer's `training_iteration` method will be called as-is each
# training iteration.
"_disable_execution_plan_api": False,
},
_allow_unknown_configs=True,
)
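The review thread above turns on how a full-episode batch becomes fixed-length sequences before storage. As an illustration only, the sketch below uses a hypothetical dict-of-lists stand-in for SampleBatch and ignores the overlap and zero-padding that RLlib's real timeslice_along_seq_lens_with_overlap() helper performs; it only shows that every column, including recurrent state columns, is cut with the same window:

```python
from typing import Dict, List


def timeslice_episode(
    episode: Dict[str, list], replay_sequence_length: int
) -> List[Dict[str, list]]:
    """Cut one full-episode batch into contiguous fixed-length slices.

    `episode` is a hypothetical stand-in for a SampleBatch: a dict that maps
    column names (obs, actions, state_in_0, ...) to equally long lists.
    """
    num_steps = len(episode["obs"])
    slices = []
    for start in range(0, num_steps, replay_sequence_length):
        end = min(start + replay_sequence_length, num_steps)
        # Every column, including state_in_*/state_out_* columns, is sliced
        # with the same window, so recurrent state stays aligned with the
        # observations and actions of its slice.
        slices.append({col: vals[start:end] for col, vals in episode.items()})
    return slices


# Toy 5-step episode with one recurrent-state column.
episode = {
    "obs": [0, 1, 2, 3, 4],
    "actions": [1, 0, 1, 1, 0],
    "state_in_0": [[0.0], [0.1], [0.2], [0.3], [0.4]],
}
for chunk in timeslice_episode(episode, replay_sequence_length=2):
    print(chunk["obs"], chunk["state_in_0"])
```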
12 changes: 8 additions & 4 deletions rllib/agents/dqn/dqn_tf_policy.py
@@ -451,10 +451,14 @@ def postprocess_nstep_and_prio(
batch[SampleBatch.DONES],
batch[PRIO_WEIGHTS],
)
new_priorities = (
np.abs(convert_to_numpy(td_errors))
+ policy.config["prioritized_replay_eps"]
)
# Retain compatibility with old-style Replay args
epsilon = policy.config.get("replay_buffer_config", {}).get(
"prioritized_replay_eps"
) or policy.config.get("prioritized_replay_eps")
if epsilon is None:
raise ValueError("prioritized_replay_eps not defined in config.")

new_priorities = np.abs(convert_to_numpy(td_errors)) + epsilon
batch[PRIO_WEIGHTS] = new_priorities

return batch
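The compatibility lookup in the hunk above can be exercised with either config layout. A small standalone sketch that mirrors just the lookup logic (not the full policy code):

```python
def lookup_prioritized_replay_eps(config: dict) -> float:
    """Resolve prioritized_replay_eps from either the new nested
    replay_buffer_config or the legacy top-level key."""
    epsilon = config.get("replay_buffer_config", {}).get(
        "prioritized_replay_eps"
    ) or config.get("prioritized_replay_eps")
    if epsilon is None:
        raise ValueError("prioritized_replay_eps not defined in config.")
    return epsilon


# New-style config: the value lives inside replay_buffer_config.
assert lookup_prioritized_replay_eps(
    {"replay_buffer_config": {"prioritized_replay_eps": 1e-6}}
) == 1e-6

# Legacy config: the value still resolves from the top level.
assert lookup_prioritized_replay_eps({"prioritized_replay_eps": 1e-6}) == 1e-6
```

Note that the `or` chain falls back to the legacy key whenever the nested value is missing (or falsy), which is what keeps old-style configs working.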
19 changes: 19 additions & 0 deletions rllib/agents/dqn/r2d2.py
@@ -29,6 +29,19 @@
# Batch mode must be complete_episodes.
"batch_mode": "complete_episodes",

# === Replay buffer ===
"replay_buffer_config": {
# For now we don't use the new ReplayBuffer API here
"_enable_replay_buffer_api": False,
"type": "MultiAgentReplayBuffer",
"capacity": 50000,
"replay_batch_size": 32,
"prioritized_replay_alpha": 0.6,
# Beta parameter for sampling from prioritized replay buffer.
"prioritized_replay_beta": 0.4,
# Epsilon to add to the TD errors when updating priorities.
"prioritized_replay_eps": 1e-6,
},
# If True, assume a zero-initialized state input (no matter where in
# the episode the sequence is located).
# If False, store the initial states along with each SampleBatch, use
@@ -66,6 +79,12 @@

# Update the target network every `target_network_update_freq` steps.
"target_network_update_freq": 2500,

# Experimental flag.
# If True, the execution plan API will not be used. Instead,
# a Trainer's `training_iteration` method will be called as-is each
# training iteration.
"_disable_execution_plan_api": False,
},
_allow_unknown_configs=True,
)
109 changes: 102 additions & 7 deletions rllib/agents/dqn/simple_q.py
@@ -15,19 +15,40 @@
from ray.rllib.agents.dqn.simple_q_tf_policy import SimpleQTFPolicy
from ray.rllib.agents.dqn.simple_q_torch_policy import SimpleQTorchPolicy
from ray.rllib.agents.trainer import Trainer, with_common_config
from ray.rllib.utils.metrics import SYNCH_WORKER_WEIGHTS_TIMER
from ray.rllib.execution.concurrency_ops import Concurrently
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.execution.replay_ops import Replay, StoreToReplayBuffer
from ray.rllib.execution.rollout_ops import ParallelRollouts
from ray.rllib.execution.rollout_ops import (
ParallelRollouts,
synchronous_parallel_sample,
)
from ray.rllib.execution.train_ops import (
MultiGPUTrainOneStep,
TrainOneStep,
MultiGPUTrainOneStep,
train_one_step,
multi_gpu_train_one_step,
)
from ray.rllib.execution.train_ops import (
UpdateTargetNetwork,
)
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.annotations import ExperimentalAPI
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import DEPRECATED_VALUE
from ray.rllib.utils.typing import TrainerConfigDict
from ray.rllib.utils.metrics import (
NUM_ENV_STEPS_SAMPLED,
NUM_AGENT_STEPS_SAMPLED,
)
from ray.rllib.utils.typing import (
ResultDict,
TrainerConfigDict,
)
from ray.rllib.utils.metrics import (
LAST_TARGET_UPDATE_TS,
NUM_TARGET_UPDATES,
)

logger = logging.getLogger(__name__)

@@ -64,9 +85,18 @@
# Size of the replay buffer. Note that if async_updates is set, then
# each worker will have a replay buffer of this size.
"buffer_size": DEPRECATED_VALUE,
# Deprecated for Simple Q because of new ReplayBuffer API
# Use MultiAgentPrioritizedReplayBuffer for prioritization.
"prioritized_replay": DEPRECATED_VALUE,
"replay_buffer_config": {
# Use the new ReplayBuffer API here
"_enable_replay_buffer_api": True,
"type": "MultiAgentReplayBuffer",
"capacity": 50000,
"replay_batch_size": 32,
# The number of contiguous environment steps to replay at once. This
# may be set to greater than 1 to support recurrent models.
"replay_sequence_length": 1,
},
# Set this to True, if you want the contents of your buffer(s) to be
# stored in any saved checkpoints as well.
@@ -76,9 +106,6 @@
# - This is False AND restoring from a checkpoint that does contain
# buffer data.
"store_buffer_in_checkpoints": False,
# The number of contiguous environment steps to replay at once. This may
# be set to greater than 1 to support recurrent models.
"replay_sequence_length": 1,

# === Optimization ===
# Learning rate for adam optimizer
@@ -108,6 +135,12 @@
"num_workers": 0,
# Prevent reporting frequency from going lower than this time span.
"min_time_s_per_reporting": 1,

# Experimental flag.
# If True, the execution plan API will not be used. Instead,
# a Trainer's `training_iteration` method will be called as-is each
# training iteration.
"_disable_execution_plan_api": True,
})
# __sphinx_doc_end__
# fmt: on
@@ -139,7 +172,9 @@ def validate_config(self, config: TrainerConfigDict) -> None:
" used at the same time!"
)

if config.get("prioritized_replay"):
if config.get("prioritized_replay") or config.get(
"replay_buffer_config", {}
).get("prioritized_replay"):
if config["multiagent"]["replay_mode"] == "lockstep":
raise ValueError(
"Prioritized replay is not supported when replay_mode=lockstep."
@@ -215,3 +250,63 @@ def execution_plan(workers, config, **kwargs):
)

return StandardMetricsReporting(train_op, workers, config)

@ExperimentalAPI
def training_iteration(self) -> ResultDict:
"""Simple Q training iteration function.

Simple Q consists of the following steps:
- (1) Sample (MultiAgentBatch) from workers...
- (2) Store new samples in replay buffer.
- (3) Sample training batch (MultiAgentBatch) from replay buffer.
- (4) Learn on training batch.
- (5) Update target network every target_network_update_freq steps.
- (6) Return all collected metrics for the iteration.

Returns:
The results dict from executing the training iteration.
"""
batch_size = self.config["train_batch_size"]
local_worker = self.workers.local_worker()

# (1) Sample (MultiAgentBatch) from workers
new_sample_batches = synchronous_parallel_sample(self.workers)

for s in new_sample_batches:
# Update counters
self._counters[NUM_ENV_STEPS_SAMPLED] += len(s)
self._counters[NUM_AGENT_STEPS_SAMPLED] += (
len(s) if isinstance(s, SampleBatch) else s.agent_steps()
)
# (2) Store new samples in replay buffer
self.local_replay_buffer.add(s)

# (3) Sample training batch (MultiAgentBatch) from replay buffer.
train_batch = self.local_replay_buffer.sample(batch_size)

# (4) Learn on training batch.
# Use simple optimizer (only for multi-agent or tf-eager; all other
# cases should use the multi-GPU optimizer, even if only using 1 GPU)
if self.config.get("simple_optimizer") is True:
train_results = train_one_step(self, train_batch)
else:
train_results = multi_gpu_train_one_step(self, train_batch)

# (5) Update target network every target_network_update_freq steps
cur_ts = self._counters[NUM_ENV_STEPS_SAMPLED]
last_update = self._counters[LAST_TARGET_UPDATE_TS]
if cur_ts - last_update >= self.config["target_network_update_freq"]:
to_update = local_worker.get_policies_to_train()
local_worker.foreach_policy_to_train(
lambda p, pid: pid in to_update and p.update_target()
)
self._counters[NUM_TARGET_UPDATES] += 1
self._counters[LAST_TARGET_UPDATE_TS] = cur_ts

# Update remote workers' weights after learning on local worker
if self.workers.remote_workers():
with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
self.workers.sync_weights()

# (6) Return all collected metrics for the iteration.
return train_results
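Taken together, the SimpleQ changes switch the trainer to the new ReplayBuffer API and to the training_iteration() path. A hypothetical usage sketch of that path follows; the import path, the SimpleQTrainer and DEFAULT_CONFIG names, and the CartPole-v0 environment are assumptions based on the surrounding RLlib code and are not verified against this branch:

```python
import ray
from ray.rllib.agents.dqn.simple_q import DEFAULT_CONFIG, SimpleQTrainer

ray.init()

config = DEFAULT_CONFIG.copy()
# New ReplayBuffer API defaults from this diff, with a smaller capacity.
config["replay_buffer_config"] = {
    "_enable_replay_buffer_api": True,
    "type": "MultiAgentReplayBuffer",
    "capacity": 10000,
    "replay_batch_size": 32,
    "replay_sequence_length": 1,
}
# Run the new training_iteration() path instead of the execution plan.
config["_disable_execution_plan_api"] = True

trainer = SimpleQTrainer(config=config, env="CartPole-v0")
for _ in range(3):
    # Each train() call drives steps (1)-(6) from the docstring above:
    # sample, store, replay, learn, maybe update the target net, report.
    results = trainer.train()
    print(results.get("episode_reward_mean"))
```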
2 changes: 0 additions & 2 deletions rllib/agents/sac/sac.py
@@ -109,8 +109,6 @@
"prioritized_replay_alpha": 0.6,
"prioritized_replay_beta": 0.4,
"prioritized_replay_eps": 1e-6,
"prioritized_replay_beta_annealing_timesteps": 20000,
"final_prioritized_replay_beta": 0.4,
# Whether to LZ4 compress observations
"compress_observations": False,

2 changes: 1 addition & 1 deletion rllib/agents/sac/tests/test_sac.py
@@ -85,7 +85,7 @@ def test_sac_compilation(self):
# If we use default buffer size (1e6), the buffer will take up
# 169.445 GB memory, which is beyond travis-ci's current (Mar 19, 2021)
# available system memory (8.34816 GB).
config["buffer_size"] = 40000
config["replay_buffer_config"]["capacity"] = 40000
# Test with saved replay buffer.
config["store_buffer_in_checkpoints"] = True
num_iterations = 1