-
Notifications
You must be signed in to change notification settings - Fork 5.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[RLlib] TF2 Bandit Agent #22838
[RLlib] TF2 Bandit Agent #22838
Conversation
tf1, tf, tfv = try_import_tf() | ||
|
||
|
||
class OnlineLinearRegression(tf.Module): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
tf.Module if tf is not None else object
(in case tf is not installed)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if tf is not installed, I feel like we should just error whenever possible?
we need this module to be tf.Module so the weights of all the arms can be checkpointed actually.
@@ -32,7 +32,12 @@ def plot_model_weights(means, covs): | |||
if __name__ == "__main__": | |||
num_iter = 10 | |||
print("Running training for %s time steps" % num_iter) | |||
trainer = BanditLinTSTrainer(env=WheelBanditEnv) | |||
config = { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could we make this a command line arg, like in most other example scripts?
--framework=[tf|tf2|torch]
and --eager-tracing
.
You can copy paste this from any other example script, e.g. rllib/examples/attention_nets.py
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Very nice! Awesome to have Bandits frameworks-complete :)
Just a few nits before merging. Thank you for this PR Jun!
Why are these changes needed?
TF2 version of Bandit.
I also think these Bandit models need to save value outputs after applying UCB updates. Otherwise the UCB exploration won't really work.
Please help double check.
Related issue number
Checks
scripts/format.sh
to lint the changes in this PR.