[RLLib] RE3 exploration algorithm TF2 framework support #25221

Merged
merged 26 commits into from
Jul 24, 2022
Changes from 24 commits
Commits
e38a825
add re3 documentation
n3011 Dec 18, 2021
0e688a0
remove empty lines
n3011 Dec 18, 2021
0170b61
Apply suggestions from code review
n30111 Dec 20, 2021
5218988
update doc string
n3011 Dec 20, 2021
530f600
Apply suggestions from code review
n30111 Dec 20, 2021
4c4c498
Merge pull request #1 from minds-ai/re3_doc
n30111 Dec 20, 2021
4f689b3
address pr comments
n3011 Dec 24, 2021
e1e78b8
Merge branch 'ray-project:master' into master
n30111 Jan 20, 2022
9168d70
Merge branch 'ray-project:master' into master
n30111 Feb 3, 2022
a36c419
Merge branch 'ray-project:master' into master
n30111 Mar 1, 2022
a6fa0dd
Merge branch 'ray-project:master' into master
n30111 Mar 1, 2022
5b8b2cf
Merge branch 'ray-project:master' into master
jbedorf Mar 11, 2022
09de5b9
Merge branch 'ray-project:master' into master
n30111 Apr 11, 2022
1beec09
Merge branch 'ray-project:master' into master
jbedorf Apr 19, 2022
6d14e7f
Merge branch 'ray-project:master' into master
n30111 May 24, 2022
4929b7d
fix tf2 issue
n3011 May 24, 2022
953061b
Merge branch 'ray-project:master' into master
n30111 May 26, 2022
2100e81
Merge branch 'master' of https://github.com/minds-ai/ray into re3_tf2…
n3011 May 27, 2022
bdeb776
Merge branch 'ray-project:master' into master
n30111 May 27, 2022
7403237
Merge branch 'master' of https://github.com/minds-ai/ray into re3_tf2…
n3011 May 27, 2022
bf0c123
get numpy array from tensor
n3011 May 27, 2022
09dee4d
Merge branch 'ray-project:master' into master
n30111 Jun 14, 2022
e47fb12
Merge branch 'ray-project:master' into master
n30111 Jun 24, 2022
c9bd456
merge with master
n3011 Jun 24, 2022
27b12ab
Merge branch 'ray-project:master' into master
n30111 Jul 15, 2022
5113a0f
fix conflicts
n3011 Jul 15, 2022
4 changes: 2 additions & 2 deletions rllib/utils/exploration/random_encoder.py
@@ -114,7 +114,7 @@ def compute_states_entropy(
     """
     obs_embeds_ = np.reshape(obs_embeds, [-1, embed_dim])
     dist = np.linalg.norm(obs_embeds_[:, None, :] - obs_embeds_[None, :, :], axis=-1)
-    return dist.argsort(axis=-1)[:, :k_nn][:, -1]
+    return dist.argsort(axis=-1)[:, :k_nn][:, -1].astype(np.float32)


@PublicAPI
@@ -288,6 +288,6 @@ def _postprocess_tf(self, policy, sample_batch, tf_sess):
         else:
             obs_embeds = tf.stop_gradient(
                 self._encoder_net({SampleBatch.OBS: sample_batch[SampleBatch.OBS]})[0]
-            )
+            ).numpy()
         sample_batch[SampleBatch.OBS_EMBEDS] = obs_embeds
         return sample_batch
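
A minimal NumPy-only sketch of what the first hunk changes, assuming a small random embedding batch. The function name `knn_index_sketch` and its `cast` flag are hypothetical and exist only for illustration; the body reuses the three expressions from `compute_states_entropy` as shown in the diff. Because `argsort` returns integer indices, the added `.astype(np.float32)` changes only the dtype of the result, not its values.

```python
import numpy as np


def knn_index_sketch(obs_embeds, embed_dim, k_nn, cast=True):
    """Hypothetical helper mirroring the expressions shown in the hunk."""
    obs_embeds_ = np.reshape(obs_embeds, [-1, embed_dim])
    # Pairwise Euclidean distances between all embeddings, shape (N, N).
    dist = np.linalg.norm(obs_embeds_[:, None, :] - obs_embeds_[None, :, :], axis=-1)
    # argsort yields integer indices; keep the last of the first k_nn columns.
    out = dist.argsort(axis=-1)[:, :k_nn][:, -1]
    # The cast added in this PR turns the integer result into float32.
    return out.astype(np.float32) if cast else out


rng = np.random.default_rng(0)
embeds = rng.normal(size=(8, 4)).astype(np.float32)
before = knn_index_sketch(embeds, embed_dim=4, k_nn=3, cast=False)
after = knn_index_sketch(embeds, embed_dim=4, k_nn=3, cast=True)
print(before.dtype, after.dtype)
```

The second hunk follows the same idea on the TF2 side: calling `.numpy()` on the eager tensor stores a plain NumPy array in the sample batch instead of a tensor.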
24 changes: 13 additions & 11 deletions rllib/utils/exploration/tests/test_random_encoder.py
@@ -3,6 +3,7 @@

 import pytest
 import ray
+from ray.rllib.utils.test_utils import framework_iterator
 import ray.rllib.algorithms.ppo as ppo
 import ray.rllib.algorithms.sac as sac
 from ray.rllib.algorithms.callbacks import RE3UpdateCallbacks
@@ -48,17 +49,18 @@ class RE3Callbacks(RE3UpdateCallbacks, config["callbacks"]):
         }

         num_iterations = 30
-        algo = algo_cls(config=config)
-        learnt = False
-        for i in range(num_iterations):
-            result = algo.train()
-            print(result)
-            if result["episode_reward_max"] > -900.0:
-                print("Reached goal after {} iters!".format(i))
-                learnt = True
-                break
-        algo.stop()
-        self.assertTrue(learnt)
+        for _ in framework_iterator(config, frameworks=("tf", "tf2"), session=True):
+            algo = algo_cls(config=config)
+            learnt = False
+            for i in range(num_iterations):
+                result = algo.train()
+                print(result)
+                if result["episode_reward_max"] > -900.0:
+                    print("Reached goal after {} iters!".format(i))
+                    learnt = True
+                    break
+            algo.stop()
+            self.assertTrue(learnt)

     def test_re3_ppo(self):
         """Tests RE3 with PPO."""
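
The control flow of the restructured test can be sketched without Ray installed. Everything below is an illustrative stand-in: `StubAlgo` and `learns_within` are hypothetical names, the reward sequence is made up, and a plain tuple of framework strings replaces `framework_iterator`. Only the `-900.0` threshold, the `num_iterations = 30` budget, and the train/break/stop pattern come from the diff.

```python
class StubAlgo:
    """Hypothetical stand-in for an RLlib algorithm; not part of Ray."""

    def __init__(self, rewards):
        self._rewards = iter(rewards)
        self.stopped = False

    def train(self):
        # Return a result dict with the single key the test inspects.
        return {"episode_reward_max": next(self._rewards)}

    def stop(self):
        self.stopped = True


def learns_within(algo, num_iterations, threshold=-900.0):
    """Train until the max episode reward exceeds the threshold or the budget ends."""
    learnt = False
    for _ in range(num_iterations):
        result = algo.train()
        if result["episode_reward_max"] > threshold:
            learnt = True
            break
    # Stop the algorithm whether or not the goal was reached.
    algo.stop()
    return learnt


results = {}
for framework in ("tf", "tf2"):
    # A fresh algorithm is built per framework, as in the updated test.
    algo = StubAlgo([-1500.0, -1200.0, -850.0])
    results[framework] = learns_within(algo, num_iterations=30)
print(results)
```

Building and stopping the algorithm inside the framework loop means each framework gets its own training run and its own pass/fail assertion.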