[RLlib] IMPALA on new API stack (w/ EnvRunner- and ConnectorV2 APIs). #42085

Merged
Commits
409 commits
e34c1ff
wip
sven1977 Dec 11, 2023
ccdd4e3
wip
sven1977 Dec 11, 2023
c43f6d4
- Changed ConnectorV2 API
sven1977 Dec 11, 2023
d91576c
merge
sven1977 Dec 11, 2023
72620f9
multi-GPU torch DDP fix
sven1977 Dec 11, 2023
d3a40ea
wip
sven1977 Dec 12, 2023
d5e2150
wip
sven1977 Dec 13, 2023
cea44b9
wip
sven1977 Dec 13, 2023
45dede9
wip
sven1977 Dec 13, 2023
05a01d5
Learns Atari Pong in ~6min on 8GPUs and 96CPUs on new stack w/ EnvRun…
sven1977 Dec 13, 2023
3c335aa
Merge branch 'master' into replace_learner_hps_with_algo_config
sven1977 Dec 13, 2023
0452070
Merge branch 'master' of https://github.com/ray-project/ray into repl…
sven1977 Dec 13, 2023
0213870
wip
sven1977 Dec 13, 2023
728bdec
Merge remote-tracking branch 'origin/replace_learner_hps_with_algo_co…
sven1977 Dec 13, 2023
a15bd29
Merge branch 'master' of https://github.com/ray-project/ray into env_…
sven1977 Dec 13, 2023
5509640
wip
sven1977 Dec 13, 2023
31a6e5c
Merge branch 'master' of https://github.com/ray-project/ray into repl…
sven1977 Dec 14, 2023
765f252
wip
sven1977 Dec 14, 2023
4da7b8c
wip
sven1977 Dec 14, 2023
6b7978a
LINT
sven1977 Dec 14, 2023
0a5380e
Merge branch 'master' of https://github.com/ray-project/ray into env_…
sven1977 Dec 14, 2023
8fc5056
wip
sven1977 Dec 14, 2023
d88fea8
Merge branch 'master' into replace_learner_hps_with_algo_config
sven1977 Dec 14, 2023
9e5ce8f
wip
sven1977 Dec 14, 2023
7e081ef
Merge remote-tracking branch 'origin/replace_learner_hps_with_algo_co…
sven1977 Dec 14, 2023
f942698
LINT
sven1977 Dec 14, 2023
9bfbef6
Merge branch 'master' of https://github.com/ray-project/ray into repl…
sven1977 Dec 14, 2023
c2317cc
wip
sven1977 Dec 14, 2023
5b9556d
Merge branch 'master' of https://github.com/ray-project/ray into repl…
sven1977 Dec 14, 2023
0286340
fix
sven1977 Dec 15, 2023
6e868be
wip
sven1977 Dec 15, 2023
55170b0
Merge branch 'master' of https://github.com/ray-project/ray into env_…
sven1977 Dec 15, 2023
1ce5658
- new framestacking connectorV2 example script runs properly
sven1977 Dec 15, 2023
4f18ce6
- fix new framestacking connectorV2 example script runs properly (w/o…
sven1977 Dec 15, 2023
73cbc63
wip
sven1977 Dec 15, 2023
292f0bb
fix
sven1977 Dec 15, 2023
a0364e1
- Learns Pong Atari on 1 GPU with the new frame stacking connector.
sven1977 Dec 15, 2023
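For readers unfamiliar with the technique named in the frame-stacking commits above, here is a rough, illustrative NumPy sketch of frame stacking in general (not the ConnectorV2 implementation from this PR; names and shapes are made up):

```python
from collections import deque
import numpy as np

# Keep the last N frames and concatenate them along the channel axis.
N = 4
frames = deque(maxlen=N)

def stack(obs):
    frames.append(obs)
    while len(frames) < N:      # pad with the first frame at episode start
        frames.append(obs)
    return np.concatenate(list(frames), axis=-1)

obs = np.zeros((84, 84, 1), dtype=np.uint8)
print(stack(obs).shape)         # (84, 84, 4)
```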
594a0f0
wip
sven1977 Dec 15, 2023
44281be
lower lr in example script a little
sven1977 Dec 15, 2023
764c18b
Merge branch 'master' of https://github.com/ray-project/ray into repl…
sven1977 Dec 15, 2023
aff98fe
Merge branch 'master' of https://github.com/ray-project/ray into repl…
sven1977 Dec 16, 2023
a7a1a5b
take out strangely failing test case
sven1977 Dec 16, 2023
b9e9914
LINT
sven1977 Dec 16, 2023
2e67f5d
Merge branch 'master' of https://github.com/ray-project/ray into env_…
sven1977 Dec 16, 2023
1d9cdf9
Merge branch 'replace_learner_hps_with_algo_config' into env_runner_s…
sven1977 Dec 16, 2023
1b29d53
LINT
sven1977 Dec 16, 2023
1abb907
merge
sven1977 Dec 16, 2023
dd86623
wip
sven1977 Dec 16, 2023
add5832
wip
sven1977 Dec 16, 2023
5e4564d
- still learning Atari Pong in ~6min on 8GPUs and 95 workers (just li…
sven1977 Dec 16, 2023
870cb45
wip: MeanStdFilter connectorV2 stuff
sven1977 Dec 18, 2023
347bec1
wip: MeanStdFilter connectorV2 stuff
sven1977 Dec 18, 2023
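The MeanStdFilter mentioned in the two commits above normalizes observations with running statistics. A self-contained sketch of that general idea (illustrative only, not the ConnectorV2 code; the update uses the standard parallel-variance combine):

```python
import numpy as np

class RunningMeanStd:
    """Running mean/std filter: maintain statistics, normalize observations."""

    def __init__(self, shape):
        self.mean = np.zeros(shape)
        self.var = np.ones(shape)
        self.count = 1e-4

    def update(self, x):
        batch_mean, batch_var, n = x.mean(axis=0), x.var(axis=0), x.shape[0]
        delta = batch_mean - self.mean
        total = self.count + n
        m2 = self.var * self.count + batch_var * n + delta**2 * self.count * n / total
        self.mean = self.mean + delta * n / total
        self.var = m2 / total
        self.count = total

    def normalize(self, x):
        return (x - self.mean) / np.sqrt(self.var + 1e-8)

rms = RunningMeanStd(shape=(4,))
batch = np.random.randn(32, 4) * 5.0 + 2.0
rms.update(batch)
print(rms.normalize(batch).mean(axis=0))  # close to zero after the update
```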
b3095c3
Merge branch 'master' of https://github.com/ray-project/ray into appo…
sven1977 Dec 18, 2023
37327e7
wip
sven1977 Dec 19, 2023
7e807ff
merge
sven1977 Dec 21, 2023
e139c24
wip
sven1977 Dec 22, 2023
bedd4fa
merge
sven1977 Dec 22, 2023
d0f40ba
IMPALA learns CartPole-v1 :)
sven1977 Dec 23, 2023
ee64c86
fix bug in SampleBatch slice (INFOs dropped from original sample batc…
sven1977 Dec 24, 2023
63c286e
wip
sven1977 Dec 26, 2023
48658ff
wip
sven1977 Dec 27, 2023
3941611
New stack throughput w/o learning step (8 env runners)
sven1977 Dec 27, 2023
2322ea3
wip
sven1977 Dec 28, 2023
b44133b
fixes
sven1977 Dec 28, 2023
ff79dd2
wip
sven1977 Jan 2, 2024
0fe9e06
merge
sven1977 Jan 5, 2024
90daa55
CartPole-v1 learning w/ 2 CPU learners
sven1977 Jan 8, 2024
1693d52
wip
sven1977 Jan 9, 2024
a1c3050
wip
sven1977 Jan 9, 2024
e22b0e3
wip
sven1977 Jan 10, 2024
b34d330
wip
sven1977 Jan 10, 2024
43f3826
Merge branch 'master' of https://github.com/ray-project/ray into appo…
sven1977 Jan 10, 2024
59e4ed6
CartPole-v1 learnt properly:
sven1977 Jan 10, 2024
be39b1e
Merge branch 'master' of https://github.com/ray-project/ray into env_…
sven1977 Jan 10, 2024
8f6ac5c
APPO learns CartPole-v1 much better and faster than old stack.
sven1977 Jan 11, 2024
97dcb45
wip
sven1977 Jan 11, 2024
3bfd5d2
Merge branch 'master' of https://github.com/ray-project/ray into env_…
sven1977 Jan 11, 2024
3825475
fixes
sven1977 Jan 11, 2024
2ca006d
fixes
sven1977 Jan 11, 2024
252f7a3
fixes
sven1977 Jan 11, 2024
e7bc45b
fixes
sven1977 Jan 11, 2024
e3ba317
fixes
sven1977 Jan 11, 2024
854b87a
merge
sven1977 Jan 11, 2024
b3e7b85
fixes
sven1977 Jan 11, 2024
938d02f
CartPole learning broke completely. Commit (and the merge one before …
sven1977 Jan 11, 2024
948c4a3
CartPole learning broke completely. Still bisecting
sven1977 Jan 11, 2024
b9fec59
fix
sven1977 Jan 11, 2024
369070a
CartPole-v1 learning again!
sven1977 Jan 11, 2024
d56a11d
wip
sven1977 Jan 11, 2024
8d4ba0a
fixes
sven1977 Jan 12, 2024
ca3782f
Merge branch 'master' of https://github.com/ray-project/ray into env_…
sven1977 Jan 12, 2024
90c91ba
fixes
sven1977 Jan 12, 2024
0039b13
fixes
sven1977 Jan 12, 2024
4605dcb
fixes
sven1977 Jan 12, 2024
3583e80
Merge branch 'master' of https://github.com/ray-project/ray into env_…
sven1977 Jan 12, 2024
af25485
fixes
sven1977 Jan 12, 2024
85c5e87
Merge branch 'env_runner_support_connectors_06_small_changes_on_env_r…
sven1977 Jan 12, 2024
c0071be
wip
sven1977 Jan 12, 2024
005414d
merge
sven1977 Jan 15, 2024
b9ecfc8
Merge branch 'master' of https://github.com/ray-project/ray into appo…
sven1977 Jan 18, 2024
2c88209
Merge branch 'master' of https://github.com/ray-project/ray into appo…
sven1977 Feb 28, 2024
5488859
wip
sven1977 Feb 28, 2024
7788110
wip
sven1977 Feb 29, 2024
9e068de
Merge branch 'master' of https://github.com/ray-project/ray into appo…
sven1977 Feb 29, 2024
83c2af1
wip
sven1977 Mar 1, 2024
1296812
Merge branch 'master' of https://github.com/ray-project/ray into appo…
sven1977 Mar 1, 2024
dfca0a6
wip
sven1977 Mar 1, 2024
efddbb0
Merge branch 'master' of https://github.com/ray-project/ray into appo…
sven1977 Mar 4, 2024
5f1aeaa
wip
sven1977 Mar 4, 2024
0c3c92d
merge
sven1977 Mar 11, 2024
c5ece2c
CartPole-v1 Impala learning faster than old API stack
sven1977 Mar 12, 2024
0af40c9
Merge branch 'master' of https://github.com/ray-project/ray into appo…
sven1977 Mar 12, 2024
c3432ae
wip
sven1977 Mar 12, 2024
7fd73fa
fix
sven1977 Mar 12, 2024
1703291
LINT
sven1977 Mar 13, 2024
ac6d613
Merge branch 'master' of https://github.com/ray-project/ray into depr…
sven1977 Mar 13, 2024
7ef5bd9
change max-in-flight to 1 on both samplers and learners
sven1977 Mar 13, 2024
26d868a
wip
sven1977 Mar 13, 2024
bbc9f2a
wip
sven1977 Mar 15, 2024
b02d34d
Merge branch 'master' of https://github.com/ray-project/ray into appo…
sven1977 Mar 15, 2024
1078fa3
add learner block for >1 GPUs (only for max-in-flight=1 thus far!)
sven1977 Mar 15, 2024
7534b74
wip
sven1977 Mar 16, 2024
25cdfb4
Merge branch 'master' of https://github.com/ray-project/ray into appo…
sven1977 Mar 16, 2024
0368ed8
Merge branch 'master' of https://github.com/ray-project/ray into appo…
sven1977 Mar 20, 2024
eafef4a
wip
sven1977 Mar 20, 2024
f3bdc75
wip
sven1977 Mar 20, 2024
758d0da
wip
sven1977 Mar 20, 2024
5dd5ed6
wip
sven1977 Mar 20, 2024
96fcfd2
wip
sven1977 Mar 20, 2024
83962c2
wip
sven1977 Mar 20, 2024
b8249c2
wip
sven1977 Mar 20, 2024
de97a25
wip
sven1977 Mar 20, 2024
7a448d9
wip
sven1977 Mar 20, 2024
d464c51
wip
sven1977 Mar 20, 2024
89be8bf
wip
sven1977 Mar 20, 2024
3e4a898
wip
sven1977 Mar 20, 2024
f586553
wip
sven1977 Mar 20, 2024
7a88bfc
wip
sven1977 Mar 20, 2024
fa2c9dc
wip
sven1977 Mar 20, 2024
c12dba4
wip
sven1977 Mar 20, 2024
55faaa1
Reactivate training.
sven1977 Mar 21, 2024
d1d4c70
wip
sven1977 Mar 21, 2024
3588bb6
Merge branch 'master' of https://github.com/ray-project/ray into appo…
sven1977 Mar 21, 2024
8527933
merge
sven1977 May 3, 2024
c0b1963
wip
sven1977 May 3, 2024
f8ea373
wip
sven1977 May 4, 2024
6db6784
wip
sven1977 May 4, 2024
bffaebb
wip
sven1977 May 4, 2024
239e85e
Merge branch 'fix_async_vector_envs_in_sa_env_runner' into appo_on_ne…
sven1977 May 4, 2024
ddac014
IMPALA CartPole on 2 CPU Learners learning.
sven1977 May 6, 2024
4dc2d6b
Merge branch 'master' of https://github.com/ray-project/ray into appo…
sven1977 May 6, 2024
c6f1138
cleanups
sven1977 May 6, 2024
fd99b97
remove additional_update from IMPALA (thus far w/o replacement) -> ne…
sven1977 May 7, 2024
3b8f327
added functionality for additional_update to `update_from_..` in IMPA…
sven1977 May 7, 2024
087e12f
Fix async update mechanism in LearnerGroup to NOT be restricted to ha…
sven1977 May 7, 2024
6687c0e
debug
sven1977 May 7, 2024
62a5dcb
add more timer stats
sven1977 May 7, 2024
3f42f8a
Fix metrics bug with tensor in the results dict arriving at Algo. -> …
sven1977 May 7, 2024
5f73cba
fix
sven1977 May 7, 2024
f088717
wip
sven1977 May 7, 2024
f62072b
minor fixes:
sven1977 May 7, 2024
32cf3f4
more enhancements (after finding out that GPU loader thread is a bott…
sven1977 May 7, 2024
87a6d9a
- Use pin_memory() on all tensors created on CPU (before .to("cuda") …
sven1977 May 8, 2024
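The commit above refers to the usual PyTorch pinned-memory pattern. A minimal sketch of that pattern (not code from this PR; the tensor shape is a stand-in):

```python
import torch

# Pin CPU tensors so the host-to-device copy can run asynchronously.
batch_cpu = torch.randn(1024, 84, 84, 4)       # stand-in for a sample batch
batch_pinned = batch_cpu.pin_memory()          # page-locked host memory
if torch.cuda.is_available():
    # non_blocking=True only overlaps with compute when the source is pinned.
    batch_gpu = batch_pinned.to("cuda", non_blocking=True)
```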
2c3ec52
- Fixes to make pinned memory pre-loading (non-blocking) work.
sven1977 May 8, 2024
a561d77
set default num threads back to 16
sven1977 May 8, 2024
eb7768e
debug bypassing LearnerConnector
sven1977 May 8, 2024
293b43b
better metrics and a smaller CNN model for impala-pong-fast.
sven1977 May 8, 2024
1d9c126
fix
sven1977 May 8, 2024
06b20a6
fix
sven1977 May 8, 2024
bb6c67e
fix
sven1977 May 8, 2024
97dd291
fix
sven1977 May 8, 2024
80a2a88
fix
sven1977 May 8, 2024
c12e554
- HACK-FIX too-low env-steps-trained counts. We currently only return…
sven1977 May 8, 2024
216dfa0
Make learner queue LIFO and limit its size (via config learner_queue_…
sven1977 May 8, 2024
4667d16
- Make learner queue deque.
sven1977 May 9, 2024
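A rough sketch of the bounded, newest-first queue behavior the two commits above describe (variable names are hypothetical; this is not the RLlib learner queue):

```python
from collections import deque

# Bounded queue: appending past maxlen silently drops the oldest entry;
# popping from the right consumes the newest batch first (LIFO).
learner_queue = deque(maxlen=4)

for batch_id in range(6):
    learner_queue.append({"batch": batch_id})

newest = learner_queue.pop()
print(newest, len(learner_queue))  # {'batch': 5} 3
```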
3ed18ec
- Fix bug in Metrics.log_n_dicts (if key does NOT exist yet in self, …
sven1977 May 10, 2024
3acd8e7
debug changes
sven1977 May 10, 2024
858798a
fixed: NUM_ENV_STEPS_SAMPLED_LIFETIME >> NUM_ENV_STEPS_TRAINED_LIFETI…
sven1977 May 10, 2024
f62d48a
- Seems to be linearly scalable now on GPU-axis wrt `NUM_ENV_STEPS_TR…
sven1977 May 10, 2024
8e14efc
merge
sven1977 May 10, 2024
3dc44de
LINT
sven1977 May 10, 2024
a354739
fix
sven1977 May 10, 2024
1847796
fix
sven1977 May 10, 2024
cf7482b
fix
sven1977 May 10, 2024
23fd5aa
naming fixes: envs_per_worker -> envs_per_env_runner
sven1977 May 10, 2024
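If the renamed setting is exposed the way the commit message suggests, usage looks roughly like this (method and keyword names are assumptions; check the AlgorithmConfig of your Ray version):

```python
from ray.rllib.algorithms.impala import ImpalaConfig

config = (
    ImpalaConfig()
    .environment("CartPole-v1")
    .env_runners(num_envs_per_env_runner=2)  # formerly: num_envs_per_worker
)
```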
2934ac3
- IMPALA: make loss based on mean again (instead of sum).
sven1977 May 12, 2024
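A toy illustration of the reduction change mentioned above (made-up numbers; this is not the IMPALA loss itself):

```python
import torch

per_timestep_loss = torch.tensor([0.2, 0.4, 0.6, 0.8])
loss_sum = per_timestep_loss.sum()    # scales with rollout/batch length
loss_mean = per_timestep_loss.mean()  # invariant to rollout/batch length
print(loss_sum.item(), loss_mean.item())  # 2.0 0.5
```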
c23f0bd
wip
sven1977 May 12, 2024
23d9578
wip
sven1977 May 12, 2024
0f5d9da
fixes
sven1977 May 13, 2024
f483dcf
Merge branch 'remove_all_rl_module_checking' into appo_on_new_api_sta…
sven1977 May 13, 2024
ec427e3
limit `.reduce()` calls at end of training_step (if no new results ar…
sven1977 May 13, 2024
82c4d97
Merge branch 'master' of https://github.com/ray-project/ray into appo…
sven1977 May 13, 2024
b3fdaff
merge
sven1977 May 13, 2024
a0c34d2
Merge branch 'master' of https://github.com/ray-project/ray into appo…
sven1977 May 14, 2024
5ed5784
Merge branch 'master' of https://github.com/ray-project/ray into appo…
sven1977 May 14, 2024
16ec8ad
wip
sven1977 May 15, 2024
0af3228
wip
sven1977 May 15, 2024
bbd5ec2
Merge branch 'master' of https://github.com/ray-project/ray into clea…
sven1977 May 15, 2024
b15a6fe
merge
sven1977 Jun 4, 2024
2c604ea
merge
sven1977 Jun 4, 2024
712e3eb
Merge branch 'master' of https://github.com/ray-project/ray into appo…
sven1977 Jun 4, 2024
d4f83a6
fixes
sven1977 Jun 4, 2024
092ddaa
Merge branch 'master' of https://github.com/ray-project/ray into appo…
sven1977 Jun 5, 2024
6a3611d
wip
sven1977 Jun 5, 2024
6137312
wip
sven1977 Jun 5, 2024
6dcd3fc
num_gpu=0 -> lr=0 fix
sven1977 Jun 5, 2024
2ec939e
cleanup
sven1977 Jun 5, 2024
fec1e39
wip
sven1977 Jun 6, 2024
375c7ba
Merge branch 'master' of https://github.com/ray-project/ray into clea…
sven1977 Jun 6, 2024
df4192b
wip
sven1977 Jun 6, 2024
1597425
merge
sven1977 Jun 6, 2024
65f32d2
Merge branch 'master' of https://github.com/ray-project/ray into clea…
sven1977 Jun 6, 2024
b8b459f
LINT and test case and docstring
sven1977 Jun 6, 2024
682da4d
SAEnvRunner bug fix overriding an already provided model_config_dict …
sven1977 Jun 6, 2024
2b8f2d0
Merge branch 'cleanup_examples_folder_10_custom_rl_module.py' into ap…
sven1977 Jun 6, 2024
eef28b8
SAEnvRunner bug fix overriding an already provided model_config_dict …
sven1977 Jun 6, 2024
926eff4
finetune tiny CNN Atari (Pong) example
sven1977 Jun 6, 2024
239091b
better init of last logits layer in tiny CNN
sven1977 Jun 6, 2024
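The commit above does not say which initializer is used; as one common way to realize a "better init of the last logits layer", a small-gain init keeps the initial policy close to uniform. Purely illustrative sketch with made-up layer sizes:

```python
import torch.nn as nn

logits_layer = nn.Linear(256, 6)  # e.g., 6 discrete Atari actions; sizes made up
nn.init.orthogonal_(logits_layer.weight, gain=0.01)  # small gain on the final layer
nn.init.zeros_(logits_layer.bias)
```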
2f5b37d
wip
sven1977 Jun 7, 2024
a22d119
further fine tune
sven1977 Jun 7, 2024
5e191f6
fix
sven1977 Jun 7, 2024
2b72814
LINT
sven1977 Jun 7, 2024
4e7033d
Merge branch 'master' of https://github.com/ray-project/ray into appo…
sven1977 Jun 7, 2024
fc9c0e1
LINT
sven1977 Jun 7, 2024
f1c4cb8
wip
sven1977 Jun 7, 2024
32c1b96
Merge branch 'master' of https://github.com/ray-project/ray into appo…
sven1977 Jun 14, 2024
31aaa8b
wip
sven1977 Jun 14, 2024
293b417
wip
sven1977 Jun 14, 2024
3fd3b52
wip
sven1977 Jun 14, 2024
03cdbc5
Merge branch 'master' of https://github.com/ray-project/ray into appo…
sven1977 Jun 17, 2024
6eb2547
wip
sven1977 Jun 17, 2024
3863383
wip
sven1977 Jun 17, 2024
097ea43
wip
sven1977 Jun 17, 2024
44224ce
wip
sven1977 Jun 17, 2024
1a3e054
wip
sven1977 Jun 17, 2024
6a417ff
wip
sven1977 Jun 18, 2024
f666e3a
Merge branch 'master' of https://github.com/ray-project/ray into appo…
sven1977 Jun 18, 2024
1cb2a32
wip
sven1977 Jun 18, 2024
67476ee
wip
sven1977 Jun 18, 2024
73ed1f9
wip
sven1977 Jun 18, 2024
5efd05e
wip
sven1977 Jun 18, 2024
36eaaf9
wip
sven1977 Jun 18, 2024
1b2a24b
LINT
sven1977 Jun 18, 2024
20007eb
wip
sven1977 Jun 18, 2024
91e16f9
wip
sven1977 Jun 18, 2024
b799260
wip
sven1977 Jun 18, 2024
954d38e
wip
sven1977 Jun 18, 2024
499e4ea
Merge branch 'master' of https://github.com/ray-project/ray into appo…
sven1977 Jun 19, 2024
fb48956
wip
sven1977 Jun 19, 2024
1eb7205
wip
sven1977 Jun 19, 2024
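The diffs below gate new behavior on config._enable_new_api_stack. For orientation, a minimal, assumption-laden sketch of building IMPALA with that flag enabled (ImpalaConfig and the experimental() call are taken from public RLlib APIs of that era; exact spellings may differ by Ray version, and this is not code from the PR):

```python
from ray.rllib.algorithms.impala import ImpalaConfig

config = (
    ImpalaConfig()
    .environment("CartPole-v1")
    # Flag name as referenced in the diff below; how it is toggled may differ
    # between Ray versions.
    .experimental(_enable_new_api_stack=True)
    .training(lr=0.0005)
)
algo = config.build()  # this PR also renames build() to build_algorithm()
result = algo.train()
print(sorted(result.keys())[:5])  # metric names differ between the two stacks
```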

ci/lint/format.sh (2 changes: 1 addition & 1 deletion)
@@ -108,7 +108,7 @@ else
echo "WARNING: clang-format is not installed!"
fi

if command -v java >/dev/null; then
if 0; then #command -v java >/dev/null; then
if [ ! -f "$GOOGLE_JAVA_FORMAT_JAR" ]; then
echo "Java code format tool google-java-format.jar is not installed, start to install it."
wget https://github.com/google/google-java-format/releases/download/google-java-format-1.7/google-java-format-1.7-all-deps.jar -O "$GOOGLE_JAVA_FORMAT_JAR"

python/ray/train/_internal/worker_group.py (2 changes: 1 addition & 1 deletion)
@@ -87,7 +87,7 @@ def construct_metadata() -> WorkerMetadata:
node_id = ray.get_runtime_context().get_node_id()
node_ip = ray.util.get_node_ip_address()
hostname = socket.gethostname()
accelerator_ids = ray.get_runtime_context().get_accelerator_ids()
accelerator_ids = ray.get_runtime_context().get_resource_ids()
pid = os.getpid()

return WorkerMetadata(

rllib/BUILD (30 changes: 20 additions & 10 deletions)
@@ -150,6 +150,16 @@ py_test(
# --------------------------------------------------------------------

# APPO
py_test(
name = "learning_tests_cartpole_appo_w_env_runner",
main = "tests/run_regression_tests.py",
tags = ["team:rllib", "exclusive", "learning_tests", "learning_tests_cartpole", "learning_tests_discrete"],
size = "large",
srcs = ["tests/run_regression_tests.py"],
data = ["tuned_examples/appo/cartpole_appo_envrunner.py"],
args = ["--dir=tuned_examples/appo/"]
)

py_test(
name = "learning_tests_cartpole_appo_w_rl_modules_and_learner",
main = "tests/run_regression_tests.py",
@@ -301,15 +311,15 @@
)

# IMPALA
# py_test(
# name = "learning_tests_cartpole_impala",
# main = "tests/run_regression_tests.py",
# tags = ["team:rllib", "exclusive", "learning_tests", "learning_tests_cartpole", "learning_tests_discrete"],
# size = "large",
# srcs = ["tests/run_regression_tests.py"],
# data = ["tuned_examples/impala/cartpole-impala.yaml"],
# args = ["--dir=tuned_examples/impala"]
# )
py_test(
name = "learning_tests_cartpole_impala_w_env_runner",
main = "tests/run_regression_tests.py",
tags = ["team:rllib", "exclusive", "learning_tests", "learning_tests_cartpole", "learning_tests_discrete"],
size = "large",
srcs = ["tests/run_regression_tests.py"],
data = ["tuned_examples/impala/cartpole_impala_envrunner.py"],
args = ["--dir=tuned_examples/impala/"]
)

py_test(
name = "learning_tests_cartpole_separate_losses_impala",
@@ -590,7 +600,7 @@
srcs = ["algorithms/dreamerv3/tests/test_dreamerv3.py"]
)

# Impala
# IMPALA
py_test(
name = "test_impala",
tags = ["team:rllib", "algorithms_dir"],

rllib/algorithms/algorithm.py (70 changes: 52 additions & 18 deletions)
@@ -638,7 +638,7 @@ def setup(self, config: AlgorithmConfig) -> None:
)

# Ensure remote workers are initially in sync with the local worker.
self.workers.sync_weights()
#self.workers.sync_weights()

# Compile, validate, and freeze an evaluation config.
self.evaluation_config = self.config.get_evaluation_config_object()
@@ -722,7 +722,6 @@ def setup(self, config: AlgorithmConfig) -> None:
# Need to add back method_type in case Algorithm is restored from checkpoint
method_config["type"] = method_type

self.learner_group = None
if self.config._enable_new_api_stack:
local_worker = self.workers.local_worker()
env = spaces = None
@@ -781,11 +780,16 @@ def setup(self, config: AlgorithmConfig) -> None:
lambda w: w.set_is_policy_to_train(policies_to_train),
healthy_only=True,
)

# Sync the weights from the learner group to the rollout workers.
weights = self.learner_group.get_weights()
local_worker.set_weights(weights)
self.workers.sync_weights()
# Sync the weights from the learner group to the rollout workers.
weights = self.learner_group.get_weights()
local_worker.set_weights(weights)
self.workers.sync_weights()
# New stack/EnvRunner APIs: Use get/set_state (no more get/set_weights).
else:
# Sync the weights from the learner group to the rollout workers.
weights = self.learner_group.get_weights()
local_worker.set_state({"rl_module": weights})
self.workers.sync_weights()

# Run `on_algorithm_init` callback after initialization is done.
self.callbacks.on_algorithm_init(algorithm=self)
@@ -876,11 +880,13 @@ def step(self) -> ResultDict:
config=self.config,
)

episodes_this_iter = collect_episodes(
self.workers,
self._remote_worker_ids_for_metrics(),
timeout_seconds=self.config.metrics_episode_collection_timeout_s,
)
episodes_this_iter = results.pop("_episodes_this_iter", None)
if episodes_this_iter is None:
episodes_this_iter = collect_episodes(
self.workers,
self._remote_worker_ids_for_metrics(),
timeout_seconds=self.config.metrics_episode_collection_timeout_s,
)
results = self._compile_iteration_results(
episodes_this_iter=episodes_this_iter,
step_ctx=train_iter_ctx,
@@ -1386,9 +1392,15 @@ def _evaluate_async_with_env_runner(
with self._timers[SYNCH_ENV_CONNECTOR_STATES_TIMER]:
# Merge connector states from all EnvRunners and broadcast updated
# states back to all EnvRunners.
self.evaluation_workers.sync_env_runner_states(
from_worker=self.workers.local_worker(),
env_steps_sampled=self._counters[NUM_ENV_STEPS_SAMPLED],
self.evaluation_workers.broadcast_state(
state={
self.workers.local_worker().get_state(components=[
NUM_ENV_STEPS_SAMPLED,
"env_to_module_connector",
"module_to_env_connector",
])
},
local_worker=True,
)

if self.evaluation_workers is None and (
@@ -3182,19 +3194,41 @@ def _run_one_training_iteration(self) -> Tuple[ResultDict, "TrainIterCtx"]:
if self.config.get("framework") == "tf2" and not tf.executing_eagerly():
tf1.enable_eager_execution()

results = None
results = {}
training_step_results = {}
episodes_this_iter = []
# Create a step context ...
with TrainIterCtx(algo=self) as train_iter_ctx:
# .. so we can query it whether we should stop the iteration loop (e.g.
# when we have reached `min_time_s_per_iteration`).
while not train_iter_ctx.should_stop(results):
while not train_iter_ctx.should_stop(training_step_results):
# Try to train one step.
with self._timers[TRAINING_ITERATION_TIMER]:
results = self.training_step()
# TODO (sven): Add capability to reduce results over different
# iterations.
training_step_results = self.training_step()

# Collect returned episode metrics from each `training_step` call,
# so nothing gets lost (in this mode, we do NOT call get_metrics()
# here automatically, it has already been done by the
# `training_step` method).
if "_episodes_this_training_step" in training_step_results:
episodes_this_iter.extend(
training_step_results.pop("_episodes_this_training_step")
)

if training_step_results:
results = training_step_results

# With training step done. Try to bring failed workers back.
self.restore_workers(self.workers)

# Publish all episodes collected in this entire iteration (consisting of n
# `training_step` calls) to let the algo know, we do NOT have to call
# `get_metrics` anymore on all EnvRunners (already done inside `training_step`).
if episodes_this_iter:
results["_episodes_this_iter"] = episodes_this_iter

return results, train_iter_ctx

def _run_one_evaluation(
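A simplified, non-RLlib illustration of the episode-metrics plumbing added in the hunk above: each training_step() may attach the episodes it already collected, the iteration loop aggregates them, and step() can then skip a second collect_episodes() pass. Function and key names mirror the diff, but the code is a standalone sketch:

```python
def run_one_training_iteration(training_step, should_stop):
    results, training_step_results, episodes_this_iter = {}, {}, []
    while not should_stop(training_step_results):
        training_step_results = training_step()
        # Episodes already collected inside training_step() are pulled out ...
        episodes_this_iter.extend(
            training_step_results.pop("_episodes_this_training_step", [])
        )
        if training_step_results:
            results = training_step_results
    # ... and republished once for the whole iteration.
    if episodes_this_iter:
        results["_episodes_this_iter"] = episodes_this_iter
    return results

calls = iter([{"loss": 0.5, "_episodes_this_training_step": ["ep1"]}])
print(run_one_training_iteration(lambda: next(calls), lambda r: bool(r)))
# -> {'loss': 0.5, '_episodes_this_iter': ['ep1']}
```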

rllib/algorithms/algorithm_config.py (6 changes: 5 additions & 1 deletion)
@@ -812,7 +812,7 @@ def validate(self) -> None:
# Check to-be-deprecated settings (however that are still in use).
self._validate_to_be_deprecated_settings()

def build(
def build_algorithm(
self,
env: Optional[Union[str, EnvType]] = None,
logger_creator: Optional[Callable[[], Logger]] = None,
@@ -4295,6 +4295,10 @@ def _resolve_tf_settings(self, _tf1, _tfv):
"speed as with static-graph mode."
)

@Deprecated(new="AlgorithmConfig.build_algorithm()", error=False)
def build(self, *args, **kwargs):
return self.build_algorithm(*args, **kwargs)

@property
@Deprecated(
old="AlgorithmConfig.multiagent['[some key]']",
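As a usage note for the rename in the last hunk, a hedged sketch (PPOConfig is just an example config class; the deprecated alias behaves as the @Deprecated(error=False) wrapper above suggests):

```python
from ray.rllib.algorithms.ppo import PPOConfig

config = PPOConfig().environment("CartPole-v1")

algo = config.build_algorithm()  # new name introduced in this PR
# algo = config.build()          # still works; forwards with a deprecation notice
```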