
Support for RE3 exploration algorithm #19551

Merged: 13 commits merged into ray-project:master on Dec 7, 2021

Conversation

@n30111 (Contributor) commented Oct 20, 2021

Why are these changes needed?

This PR adds support for RE3 (Random Encoders for Efficient Exploration), a simple and efficient exploration method for off-policy RL algorithms that can also be used with on-policy algorithms.

It implements "State Entropy Maximization with Random Encoders for Efficient Exploration" (Seo, Chen, Shin, Lee, Abbeel, & Lee, 2021; arXiv preprint arXiv:2102.09430).
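For readers unfamiliar with the method, here is a minimal sketch of the per-batch state-entropy bonus RE3 computes; the helper name and the plain-numpy form are illustrative only, not the PR's actual compute_states_entropy:

import numpy as np

def knn_state_entropy(obs_embeds: np.ndarray, k_nn: int = 3) -> np.ndarray:
    # obs_embeds: [batch, embed_dim] outputs of a fixed, randomly
    # initialized encoder applied to the observations.
    dists = np.linalg.norm(
        obs_embeds[:, None, :] - obs_embeds[None, :, :], axis=-1)
    # Distance of each embedding to its k-th nearest neighbor in the batch
    # (index 0 of each sorted row is the point itself, at distance 0).
    knn_dists = np.sort(dists, axis=-1)[:, k_nn]
    # RE3 intrinsic reward: log(dist + 1). It is added to the extrinsic
    # reward with a (decaying) beta coefficient during optimization.
    return np.log(knn_dists + 1.0)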

Related issue number

Checks

  • I've run scripts/format.sh to lint the changes in this PR.
  • I've included any doc changes needed for https://docs.ray.io/en/master/.
  • I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
  • Testing Strategy
    • Unit tests
    • Release tests
    • This PR is not tested :(

@gjoliver gjoliver self-requested a review November 4, 2021 03:50
)
from ray.rllib.policy.sample_batch import SampleBatch

embeds_dim = self.config["exploration_config"].get(
Member:

Why don't we define all these configurations and the UpdateCallback with the RandomEncoder in that re3.py file under rllib/utils/exploration?
I don't think we need to configure the callback here automatically.
All we need is an example script showing how to construct a trainer with RE3 exploration using the exploration type field and an RE3UpdateCallback callback, as sketched below.
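For illustration, a rough sketch of what such an example script could look like; the config keys, the env name, and the callback kwargs are assumptions pieced together from snippets elsewhere in this thread and may differ from the merged example:

from functools import partial

import ray
from ray.rllib.agents import sac
from ray.rllib.utils.exploration.random_encoder import RE3UpdateCallbacks

config = sac.DEFAULT_CONFIG.copy()
config["env"] = "Pendulum-v1"  # placeholder env
config["framework"] = "tf"  # the initial RE3 implementation is tf-only
# Patch in the RE3 callback that adds the intrinsic reward to train batches.
config["callbacks"] = partial(
    RE3UpdateCallbacks, embeds_dim=128, beta_schedule="linear_decay", k_nn=50)
# Select RE3 via the regular exploration "type" field.
config["exploration_config"] = {
    "type": "RE3",
    # Sub-exploration used to actually sample actions.
    "sub_exploration": {"type": "StochasticSampling"},
}

ray.init()
trainer = sac.SACTrainer(config=config)
print(trainer.train())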

Contributor Author:

Yes, that would be cleaner. I will make the changes.

@@ -698,6 +698,7 @@ def _get_default_view_requirements(self):
SampleBatch.UNROLL_ID: ViewRequirement(),
SampleBatch.AGENT_INDEX: ViewRequirement(),
"t": ViewRequirement(),
SampleBatch.OBS_EMBEDS: ViewRequirement(),
Member:

move above "t" maybe.

Contributor Author:

updated

@@ -40,6 +40,7 @@ class SampleBatch(dict):
DONES = "dones"
INFOS = "infos"
SEQ_LENS = "seq_lens"
OBS_EMBEDS = "obs_embeds"
Member:

Please comment that this field is only computed and used when the RE3 exploration strategy is enabled.

Contributor Author:

updated

def on_train_result(self, *, trainer, result: dict,
**kwargs) -> None:
# Keep track of the training iteration for beta decay.
UpdateCallbacks._step = result["training_iteration"]
Member:

I think you can just use policy.global_timestep, so there is no need to track this yourself?
Another way is to add a MixIn for this beta schedule, much like LearningRateSchedule and EntropyCoeffSchedule; you can then just use the auto-decayed beta value here.

Contributor Author:

Here, we want to decay based on training_iteration. It seems policy.global_timestep depends on the number of batches processed and is not equal to training_iteration (just looking at the source, not 100% sure).
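A minimal sketch of that approach (the class body is abridged and _decayed_beta is a hypothetical helper; the schedule formula matches the one reviewed further below):

from ray.rllib.agents.callbacks import DefaultCallbacks

class RE3UpdateCallbacks(DefaultCallbacks):
    _step = 0  # shared training-iteration counter, updated once per train() call

    def on_train_result(self, *, trainer, result: dict, **kwargs) -> None:
        # Track the trainer-level iteration; policy.global_timestep instead
        # grows with the number of sampled/processed timesteps.
        RE3UpdateCallbacks._step = result["training_iteration"]

    @staticmethod
    def _decayed_beta(beta: float, rho: float) -> float:
        # "linear_decay" schedule used in this PR: beta * (1 - rho) ** step.
        return beta * (1.0 - rho) ** RE3UpdateCallbacks._step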

**kwargs,
):
states_entropy = compute_states_entropy(
train_batch[SampleBatch.OBS_EMBEDS], embeds_dim, k_nn)
Member:

This feels a bit weird.
I guess to calculate the k-NN distance, you want to use a batch that's randomly sampled from the replay buffer, just so you can measure how different the current sample batch is from everything you have been seeing (the RB)?
As written, the k-NN distances are calculated from a single batch, so naturally you are going to get a lot of similar samples from consecutive steps?

Contributor Author:

Yes, the k-NN distances are computed within a single batch, and we are assuming that RLlib has sampled that batch randomly from the replay buffer.

Contributor Author:

The k-NN distance and the randomness are limited to the single batch only, not the full replay buffer.

@jbedorf (Contributor) commented Nov 29, 2021

@gjoliver Could you take another look and indicate if anything else needs to be changed? Thanks!

@gjoliver (Member):

Sorry about the delay, I will take another detailed look.

@gjoliver gjoliver left a comment (Member)

This looks very exciting. Thanks for your awesome work.

I have gone through the PR carefully and left a bunch of minor/detailed comments. Nothing significant actually, except maybe a TODO for us :)

The biggest blocker right now is to rebase onto master and make sure all the tests pass; there are too many failed CI tests right now.
This PR doesn't introduce any changes to the existing codebase, so I don't expect anything tricky. Probably just need to retry by doing:

fetch upstream in your repo
git checkout master; git pull --rebase
git checkout re3; git rebase master
then git push to re-trigger the tests.

Also, please make sure you run ci/travis/format.sh so the Lint test will pass too.

Thanks again. Let me know if you need any help with the tests.
Happy to help with those logistical things.

ray.init()

config = sac.DEFAULT_CONFIG.copy()
beta_schedule = "linear_decay"
Member:

this variable doesn't seem to be used.


# Patch user given callbacks with RE3 callbacks for using RE3 exploration
# strategy
class RE3Callbacks(RE3UpdateCallbacks, config["callbacks"]):
Member:

sac.DEFAULT_CONFIG doesn't specify callbacks?
It seems like you can simply do:

from functools import partial

config["callbacks"] = partial(RE3Callbacks, embeds_dim=128,
    beta_schedule="linear_decay", k_nn=50)

here.

Contributor Author:

This was used for demonstration purposes, in case the user wants to know how it works when a callbacks class is already provided in the config.

@@ -764,6 +764,7 @@ def env_creator_from_classpath(env_context):
"`callbacks` must be a callable method that "
"returns a subclass of DefaultCallbacks, got {}".format(
self.config["callbacks"]))

Member:

can you undo this empty line diff too? thanks a lot.

delta = batch_mean - self.mean
tot_count = self.count + batch_count

self.mean = self.mean + delta + batch_count / tot_count
Member:

Hmm, sorry for the stupid question, but I don't quite get the formula for the moving mean here. As written, self.mean basically is:

self.mean + batch_mean - self.mean + batch_count / tot_count

which is:

batch_mean + batch_count / tot_count

Why are we taking batch_mean plus a small ratio as the moving mean?

Contributor Author:

Yes, this construct is adopted from the authors' reference code. I have also experimented with the correct moving-mean version ((self.mean * self.count + batch_mean * batch_count) / tot_count); the results are almost the same.

Member:

Ah ok. It's not uncommon to see a questionable reference implementation.
We can keep it as is for now.
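For reference, the standard running-mean combine mentioned in the reply above would look roughly like this (a sketch with assumed names, not what the merged code does):

def update_mean(mean: float, count: int, batch_mean: float, batch_count: int):
    tot_count = count + batch_count
    delta = batch_mean - mean
    # Equivalent to (mean * count + batch_mean * batch_count) / tot_count.
    new_mean = mean + delta * batch_count / tot_count
    return new_mean, tot_count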

Updated beta as per input schedule.
"""
if beta_schedule == "linear_decay":
return beta * ((1.0 - rho)**step)
Member:

add whitespaces around ** ?

Contributor Author:

The ci/travis/format.sh script removes the whitespace automatically.


The entropy of a state is considered as intrinsic reward and added to the
environment's extrinsic reward for policy optimization.
"""
Member:

Should we mention here that the entropy is only calculated per batch and does not take the distribution of the entire replay buffer into consideration?
Basically one of your responses above, just put here in the docstring.


Args:
action_space (Space): The action space in which to explore.
framework (str): One of "tf" or "torch".
Member:

update the comment? this implementation doesn't support torch.

rllib/utils/exploration/random_encoder.py (resolved)


if __name__ == "__main__":
import pytest
Member:

move import to top of file?


self.sub_exploration = sub_exploration

# Creates modules/layers inside the actual ModelV2.
Member:

Nit-picking:
it's not really "inside" ModelV2, is it?
More like "Creates ModelV2 embedding module / layers"?

@n30111 (Contributor Author) commented Dec 2, 2021

@gjoliver I have updated the code with the requested changes, please have a look. It would be great if you could help with the failing tests; I'm not sure why they are failing. Thank you.

@gjoliver (Member) commented Dec 2, 2021

I can totally help. Can you give me edit permission to this repo?
Or if you can help pull down the latest master, that would be awesome.
I will push some minor changes to your branch.

Thanks.

@jbedorf (Contributor) commented Dec 2, 2021

I can totally help. can you give me edit permission to this repo? or if you can help pull down the latest master, that will be awesome. I will push some minor changes to your branch.

thanks.

I've added you to the repository. Let me know if it does not work.

@gjoliver (Member) commented Dec 3, 2021

Tests are now all green!

@@ -381,3 +385,59 @@ def on_learn_on_batch(self, *, policy: Policy, train_batch: SampleBatch,
def on_train_result(self, *, trainer, result: dict, **kwargs) -> None:
for callback in self._callback_list:
callback.on_train_result(trainer=trainer, result=result, **kwargs)


# This Callback is used by the RE3 exploration strategy.
Contributor:

Could we move this into utils/exploration/random_encoder.py?

Contributor Author:

It was moved out from here to avoid a circular dependency (commit 89169cc).

Member:

Yeah, this is the best I could do.
I would love to keep everything in random_encoder.py too, but you know :)

@@ -0,0 +1,67 @@
import sys
Contributor:

This is great! Could we actually activate this test by adding it to rllib/BUILD next to the other utils/exploration tests?
This will make sure this feature never breaks in the future.

@sven1977 sven1977 left a comment (Contributor)

Hey @n30111, thanks for this really great PR! Just 2-3 smaller nits, like activating the test case, etc.
In a follow-up PR, could you also add this exploration algo to our algo docs? There is a special section in there describing the different exploration strategies.

@n30111 (Contributor Author) commented Dec 6, 2021

Hey @n30111 , thanks for this really great PR! Just 2-3 smaller nits, like activating the test case, etc...
In a follow up PR, could you also add this exploration algo to our algo docs? There is a special section in there describing the different exploration strategies.

Sure, I will update the docs in a follow up PR. Thanks.

@n30111 (Contributor Author) commented Dec 7, 2021

@gjoliver Can you please help again with the tests? I'm not sure why these tests are failing. Thanks.

@gjoliver (Member) commented Dec 7, 2021

@gjoliver Can you please help again on the tests part, not sure why these tests are failing. Thanks.

Can you update your test? The error message looks like:

from ray.rllib.utils.exploration.random_encoder import RE3UpdateCallbacks

ImportError: cannot import name 'RE3UpdateCallbacks' from 'ray.rllib.utils.exploration.random_encoder'

@n30111 (Contributor Author) commented Dec 7, 2021

@gjoliver I was getting the error due to the changes in commit 491ef89. Do you think we need to revert it, or is there another solution?

@sven1977 sven1977 left a comment (Contributor)

Awesome, thanks! :)

@@ -679,7 +679,8 @@ def _initialize_loss_from_dummy_batch(
key not in [
SampleBatch.EPS_ID, SampleBatch.AGENT_INDEX,
SampleBatch.UNROLL_ID, SampleBatch.DONES,
SampleBatch.REWARDS, SampleBatch.INFOS]:
SampleBatch.REWARDS, SampleBatch.INFOS,
SampleBatch.OBS_EMBEDS]:
Contributor:

No need to do this right now, but in a follow-up PR we should also test-run the callbacks, so that the SampleBatch access detector captures this field and automatically adds it to the view requirements (instead of always doing so).

rllib/BUILD (resolved)
@sven1977 (Contributor) commented Dec 7, 2021

Merging this now as all tests look good.
If you would like to add more changes, just do a follow-up PR.
Thanks again so much @n30111. Really nice PR!

@sven1977 sven1977 merged commit 2868d1a into ray-project:master Dec 7, 2021
@n30111 (Contributor Author) commented Dec 7, 2021

Merging this now as all tests look good.
If you would like to add more changes, just do a follow-up PR.
Thanks again so much @n30111 . Really nice PR!

Thanks, great.

@n30111 n30111 mentioned this pull request Dec 20, 2021