Support for RE3 exploration algorithm #19551
Conversation
rllib/agents/trainer.py
Outdated
)
from ray.rllib.policy.sample_batch import SampleBatch

embeds_dim = self.config["exploration_config"].get(
why don't we define all these configurations and the UpdateCallback with the RandomEncoder in that re3.py file under rllib/utils/exploration?
I don't think we need to configure the callback here automatically?
all we need is an example script showing how to construct a trainer with RE3 exploration using the exploration type field and a RE3UpdateCallback callback?
Yes, that would be cleaner. I will make the changes.
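For illustration, a minimal sketch of what the config-driven setup discussed here might look like; the key names (embeds_dim, beta, k_nn, sub_exploration) are assumptions based on this thread, not the final API:

```python
# Hypothetical config sketch: select RE3 purely via the exploration type field.
# All parameter names below are assumptions taken from this discussion.
config = {
    "exploration_config": {
        "type": "RE3",      # class registered under rllib/utils/exploration
        "embeds_dim": 128,  # output size of the random encoder
        "beta": 0.2,        # weight of the intrinsic (state-entropy) reward
        "k_nn": 50,         # k used for the k-nearest-neighbor entropy estimate
        "sub_exploration": {
            "type": "StochasticSampling",  # exploration applied on top of RE3
        },
    },
}
```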
rllib/policy/policy.py
Outdated
@@ -698,6 +698,7 @@ def _get_default_view_requirements(self):
SampleBatch.UNROLL_ID: ViewRequirement(),
SampleBatch.AGENT_INDEX: ViewRequirement(),
"t": ViewRequirement(),
SampleBatch.OBS_EMBEDS: ViewRequirement(),
move above "t" maybe.
updated
@@ -40,6 +40,7 @@ class SampleBatch(dict):
DONES = "dones"
INFOS = "infos"
SEQ_LENS = "seq_lens"
OBS_EMBEDS = "obs_embeds"
please comment that this field is only computed and used when RE3 exploration strategy is enabled.
updated
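Something along these lines, for example (the wording of the comment is only a suggestion):

```python
# Only computed and used when the RE3 exploration strategy is enabled.
OBS_EMBEDS = "obs_embeds"
```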
rllib/agents/trainer.py
Outdated
def on_train_result(self, *, trainer, result: dict,
                    **kwargs) -> None:
    # Keep track of the training iteration for beta decay.
    UpdateCallbacks._step = result["training_iteration"]
I think you can just use policy.global_timestep and there is no need to track this yourself?
another way is to add a MixIn for this beta schedule, much like LearningRateSchedule and EntropyCoeffSchedule. you can then just use the auto-decayed beta value here.
Here, we want to decay based on training_iteration. It seems policy.global_timestep depends on the number of batches processed and is not equal to training_iteration (just looking at the source, not 100% sure).
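For clarity, the pattern under discussion (a class attribute updated from on_train_result, so beta decays with the training iteration rather than with policy.global_timestep) looks roughly like this; a sketch based on the snippet above, not the final code:

```python
from ray.rllib.agents.callbacks import DefaultCallbacks


class IterationTrackingCallbacks(DefaultCallbacks):
    """Sketch: remember the current training iteration for beta decay."""

    # Class attribute, so every instance in this process reads the same value.
    _step = 0

    def on_train_result(self, *, trainer, result: dict, **kwargs) -> None:
        # result["training_iteration"] counts Trainer.train() calls, whereas
        # policy.global_timestep counts sampled environment timesteps/batches.
        IterationTrackingCallbacks._step = result["training_iteration"]
```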
rllib/agents/trainer.py
Outdated
    **kwargs,
):
    states_entropy = compute_states_entropy(
        train_batch[SampleBatch.OBS_EMBEDS], embeds_dim, k_nn)
this feels a bit weird.
I guess to calculate the knn distance, you want to use a batch that's randomly sampled from the ReplayBuffer? just so you can measure how different current sample batch is from all the things you have been seeing (RB)?
as written, this knn thing is calculated from a single batch. so naturally, you are gonna get a lot of similar samples from consecutive steps?
Yes, KNN is computed for a single batch, and we are assuming that RLlib has sampled that batch randomly from the replay buffer.
KNN distance and the randomness are limited to a single batch only, not the full replay buffer.
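To make the per-batch nature concrete, here is a simplified stand-in for compute_states_entropy (not the PR's exact implementation): the k-NN distance is measured only among the embeddings of the current train batch.

```python
import numpy as np


def compute_states_entropy_sketch(obs_embeds, embeds_dim, k_nn):
    """k-NN state-entropy estimate within a single batch of embeddings."""
    embeds = np.asarray(obs_embeds, dtype=np.float32).reshape(-1, embeds_dim)
    # Pairwise Euclidean distances inside the batch only (no replay buffer).
    dists = np.linalg.norm(embeds[:, None, :] - embeds[None, :, :], axis=-1)
    # Column 0 of the sorted distances is the point itself (distance 0), so
    # index k_nn gives the distance to the k-th nearest neighbor.
    knn_dists = np.sort(dists, axis=-1)[:, k_nn]
    # RE3-style intrinsic signal: log(1 + distance to the k-th neighbor).
    return np.log(knn_dists + 1.0)


# Example: 64 embeddings of dimension 128, k = 50 (requires k_nn < batch size).
entropy = compute_states_entropy_sketch(
    np.random.randn(64, 128), embeds_dim=128, k_nn=50)
```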
@gjoliver Could you take another look and indicate if anything else needs to be changed? Thanks!
sorry about the delay, I will take another detailed look.
this looks very exciting. thanks for your awesome work.
I have gone through the PR carefully, and had a bunch of minor / detailed comments. nothing significant actually, except maybe for a TODO for us :)
The biggest blocker right now is to rebase to master, and make sure all the tests pass. there are too many failed CI tests right now.
This PR doesn't introduce any changes to the existing codebase, so I don't expect anything tricky. Probably just need to retry by doing:
fetch upstream in your repo
git checkout master; git pull --rebase
git checkout re3; git rebase master
then git push to re-trigger the tests.
Also, please make sure you run ci/travis/format.sh so the Lint test will pass too.
Thanks again. Let me know if you need any help with the tests actually.
Happy to help on those logistic things.
rllib/examples/re3_exploration.py
Outdated
ray.init()

config = sac.DEFAULT_CONFIG.copy()
beta_schedule = "linear_decay"
this variable doesn't seem to be used.
rllib/examples/re3_exploration.py
Outdated
# Patch user given callbacks with RE3 callbacks for using RE3 exploration
# strategy
class RE3Callbacks(RE3UpdateCallbacks, config["callbacks"]):
sac.DEFAULT_CONFIG doesn't specify callbacks?
it seems like you can simply do:
config["callbacks"] = partial(RE3Callbacks, embeds_dim=128,
beta_schedule="linear_decay", k_nn=50)
here.
This was used for demonstration purposes, in case a user wants to know how it works when a callbacks class is already provided in the config.
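Both variants could be illustrated side by side in the example; a sketch, assuming RE3UpdateCallbacks is importable from where this PR defines it and that it accepts the keyword arguments mentioned in the review comment above:

```python
from functools import partial

from ray.rllib.agents import sac
# RE3UpdateCallbacks is assumed to live in agents/callbacks.py, as in this PR.
from ray.rllib.agents.callbacks import DefaultCallbacks, RE3UpdateCallbacks

config = sac.DEFAULT_CONFIG.copy()

# Variant 1: no user callbacks, so bind the RE3 parameters directly
# (the simpler form suggested in the review comment).
config["callbacks"] = partial(
    RE3UpdateCallbacks, embeds_dim=128, beta_schedule="linear_decay", k_nn=50)


# Variant 2: the user already has their own callbacks class, so mix RE3 in
# front of it so both sets of hooks run (the "patching" shown in the example).
class MyCallbacks(DefaultCallbacks):
    pass


class RE3Callbacks(RE3UpdateCallbacks, MyCallbacks):
    pass


config["callbacks"] = partial(
    RE3Callbacks, embeds_dim=128, beta_schedule="linear_decay", k_nn=50)
```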
rllib/agents/trainer.py
Outdated
@@ -764,6 +764,7 @@ def env_creator_from_classpath(env_context):
"`callbacks` must be a callable method that "
"returns a subclass of DefaultCallbacks, got {}".format(
self.config["callbacks"]))

can you undo this empty line diff too? thanks a lot.
delta = batch_mean - self.mean
tot_count = self.count + batch_count

self.mean = self.mean + delta + batch_count / tot_count
hmm sorry for the stupid question, but I don't quite get the formula for the moving mean here. as written, self.mean basically is:
self.mean + batch_mean - self.mean + batch_count / tot_count
which is:
batch_mean + batch_count / tot_count
why are we taking batch_mean plus a small ratio as the moving mean?
Yes, this construct is adopted from the author's reference code. I have also experimented with the correct moving-mean version ((self.mean * self.count + batch_mean * batch_count) / tot_count); the results are almost the same.
ah ok. it's not uncommon to see questionable reference implementation.
we can keep it as is for now.
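For reference, the standard streaming combination of two means that the algebra above points toward would look like this (a sketch for comparison, not what the reference implementation does):

```python
class RunningMeanSketch:
    """Textbook combination of a running mean with a new batch mean."""

    def __init__(self):
        self.mean = 0.0
        self.count = 0

    def update(self, batch_mean: float, batch_count: int) -> None:
        delta = batch_mean - self.mean
        tot_count = self.count + batch_count
        # Equivalent to (self.mean * self.count + batch_mean * batch_count) / tot_count:
        # the old mean is shifted toward the batch mean by the batch's weight.
        self.mean = self.mean + delta * batch_count / tot_count
        self.count = tot_count
```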
Updated beta as per input schedule.
"""
if beta_schedule == "linear_decay":
    return beta * ((1.0 - rho)**step)
add whitespaces around ** ?
The ci/travis/format.sh script removes the whitespace automatically.
The entropy of a state is considered as intrinsic reward and added to the
environment's extrinsic reward for policy optimization.
"""
should we mention here that entropy is only calculated per batch, and does not take the distribution of the entire replay buffer into consideration?
basically one of your responses, we just put it here.
Args:
    action_space (Space): The action space in which to explore.
    framework (str): One of "tf" or "torch".
update the comment? this implementation doesn't support torch.
if __name__ == "__main__":
    import pytest
move import to top of file?
self.sub_exploration = sub_exploration

# Creates modules/layers inside the actual ModelV2.
nit picking.
it's not really "inside" ModelV2?
more like "Creates ModelV2 embedding module / layers"?
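For context, the "random encoder" this comment refers to is essentially a randomly initialized, frozen embedding network. A minimal tf.keras sketch (layer sizes and names are assumptions, not the PR's exact model):

```python
import numpy as np
import tensorflow as tf


def build_random_encoder(obs_dim: int, embeds_dim: int) -> tf.keras.Model:
    """Randomly initialized encoder whose weights are never trained (RE3)."""
    inputs = tf.keras.Input(shape=(obs_dim,))
    hidden = tf.keras.layers.Dense(256, activation="relu")(inputs)
    embeds = tf.keras.layers.Dense(embeds_dim)(hidden)
    model = tf.keras.Model(inputs, embeds)
    model.trainable = False  # RE3 keeps the encoder fixed at its random init.
    return model


# Embed a batch of observations into the fixed random feature space.
encoder = build_random_encoder(obs_dim=8, embeds_dim=128)
obs_embeds = encoder(np.random.randn(32, 8).astype(np.float32)).numpy()
```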
@gjoliver I have updated the code with the requested changes, please have a look. It would be great if you could help with the failing tests; not sure why they are failing. Thank you.
I can totally help. can you give me edit permission to this repo? thanks.
I've added you to the repository. Let me know if it does not work.
…ate example to use MultiCallbacks.
…actually enabled.
Tests are now all green!
@@ -381,3 +385,59 @@ def on_learn_on_batch(self, *, policy: Policy, train_batch: SampleBatch,
def on_train_result(self, *, trainer, result: dict, **kwargs) -> None:
    for callback in self._callback_list:
        callback.on_train_result(trainer=trainer, result=result, **kwargs)


# This Callback is used by the RE3 exploration strategy.
Could we move this into utils/exploration/random_encoder.py?
It was moved out from here to avoid a circular dependency, in this commit: 89169cc
yeah, this is the best I could do.
I would love to keep everything in random_encoder.py too, but you know :)
@@ -0,0 +1,67 @@
import sys
This is great! Could we actually activate this test by adding it to rllib/BUILD next to the other utils/exploration tests? This will make sure this feature never breaks in the future.
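A hypothetical rllib/BUILD entry for this (Bazel/Starlark; the target name, tags, and path are assumptions based on the test file discussed here):

```python
py_test(
    name = "test_random_encoder",
    tags = ["team:ml", "utils"],
    size = "large",
    srcs = ["utils/exploration/tests/test_random_encoder.py"],
)
```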
Hey @n30111 , thanks for this really great PR! Just 2-3 smaller nits, like activating the test case, etc...
In a follow up PR, could you also add this exploration algo to our algo docs? There is a special section in there describing the different exploration strategies.
Sure, I will update the docs in a follow-up PR. Thanks.
@gjoliver Can you please help again with the tests? Not sure why these tests are failing. Thanks.
can you update your test? the error message looks like:
from ray.rllib.utils.exploration.random_encoder import RE3UpdateCallbacks
ImportError: cannot import name 'RE3UpdateCallbacks' from 'ray.rllib.utils.exploration.random_encoder'
Awesome, thanks! :)
@@ -679,7 +679,8 @@ def _initialize_loss_from_dummy_batch(
key not in [
SampleBatch.EPS_ID, SampleBatch.AGENT_INDEX,
SampleBatch.UNROLL_ID, SampleBatch.DONES,
SampleBatch.REWARDS, SampleBatch.INFOS]:
SampleBatch.REWARDS, SampleBatch.INFOS,
SampleBatch.OBS_EMBEDS]:
No need to do this rn, but in a follow up PR, we should also test-run the callbacks, so that the SampleBatch access detector captures this field and automatically adds it to the view-requirements (instead of always doing so).
Merging this now as all tests look good.
Thanks, great.
Why are these changes needed?
This PR adds support for RE3 (Random Encoders for Efficient Exploration). RE3 is a simple and efficient exploration method for off-policy RL algorithms; it can also be used with on-policy RL algorithms.
This is an implementation of "State Entropy Maximization with Random Encoders for Efficient Exploration" (Seo, Chen, Shin, Lee, Abbeel, & Lee, 2021), arXiv preprint arXiv:2102.09430.
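As a rough usage illustration (not the PR's exact example script; the environment name and parameter values are placeholders, and RE3UpdateCallbacks is imported from where this PR defines it):

```python
from functools import partial

import ray
from ray.rllib.agents import sac
from ray.rllib.agents.callbacks import RE3UpdateCallbacks

ray.init()

config = sac.DEFAULT_CONFIG.copy()
config["env"] = "Pendulum-v1"  # placeholder environment
# Callbacks handle the RE3 bookkeeping (obs-embedding stats, beta decay).
config["callbacks"] = partial(
    RE3UpdateCallbacks, embeds_dim=128, beta_schedule="linear_decay", k_nn=50)
# Switch the exploration strategy to RE3, wrapping a default sub-exploration.
config["exploration_config"] = {
    "type": "RE3",
    "embeds_dim": 128,
    "beta_schedule": "linear_decay",
    "sub_exploration": {"type": "StochasticSampling"},
}

trainer = sac.SACTrainer(config=config)
for _ in range(5):
    result = trainer.train()
    print(result["episode_reward_mean"])
```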
Related issue number
Checks
I've run scripts/format.sh to lint the changes in this PR.