[RLlib] Algorithm/Policy checkpoint overhaul and Policy Model export (in native formats). #28166
Conversation
…l_export_overhaul
# Conflicts:
#	rllib/BUILD
…l_export_overhaul
…l_export_overhaul
Signed-off-by: sven1977 <[email protected]>
# Conflicts:
#	rllib/examples/export/onnx_tf.py
#	rllib/examples/export/onnx_torch.py
#	rllib/policy/dynamic_tf_policy_v2.py
#	rllib/policy/eager_tf_policy_v2.py
#	rllib/policy/tests/test_policy.py
#	rllib/policy/torch_policy_v2.py
Signed-off-by: sven1977 <[email protected]>
Signed-off-by: sven1977 <[email protected]>
rllib/evaluation/rollout_worker.py (Outdated)

```diff
@@ -1457,7 +1457,7 @@ def get_filters(self, flush_after: bool = False) -> Dict:
         return return_filters

     @DeveloperAPI
-    def save(self) -> bytes:
+    def get_state(self) -> bytes:
```
Unify naming:
- `get_state()` -> returns a dict that completely describes the current state of the object.
- `set_state(state)` -> can be used to restore the object to a state (which may have been returned from a `get_state()` call).
- `save()` -> creates a checkpoint on disk and returns the path where it was saved.
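The naming convention above can be sketched in plain Python. The `Counter` class below is a hypothetical toy object for illustration, not RLlib code; only the `get_state`/`set_state`/`save` shapes mirror the convention being proposed:

```python
import os
import pickle
import tempfile


class Counter:
    """Toy object illustrating the get_state/set_state/save convention."""

    def __init__(self):
        self.count = 0

    def get_state(self):
        # Returns a dict that completely describes the current state.
        return {"count": self.count}

    def set_state(self, state):
        # Restores the object from a state previously returned by get_state().
        self.count = state["count"]

    def save(self, checkpoint_dir):
        # Creates a checkpoint on disk and returns the path it was saved to.
        # "state.pkl" is an arbitrary file name for this sketch.
        path = os.path.join(checkpoint_dir, "state.pkl")
        with open(path, "wb") as f:
            pickle.dump(self.get_state(), f)
        return path


c = Counter()
c.count = 5
with tempfile.TemporaryDirectory() as tmp:
    path = c.save(tmp)
    restored = Counter()
    with open(path, "rb") as f:
        restored.set_state(pickle.load(f))
print(restored.count)  # → 5
```

The key point of the split: `get_state`/`set_state` never touch the filesystem, while `save` composes `get_state` with on-disk serialization.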
rllib/evaluation/rollout_worker.py (Outdated)

```diff
@@ -1486,7 +1487,7 @@ def save(self) -> bytes:
         )

     @DeveloperAPI
-    def restore(self, objs: bytes) -> None:
+    def set_state(self, objs: bytes) -> None:
```
Same as above
```diff
@@ -1,59 +1,75 @@
 import argparse
```
Cleaned up our ONNX examples to better fit our typical example-script format.
rllib/policy/dynamic_tf_policy_v2.py (Outdated)

```python
state = super().get_state()
# Add this Policy's spec so it can be retrieved w/o access to the original
# code.
state["policy_spec"] = PolicySpec(
```
New:
- Add the full `PolicySpec` to the state, such that a policy can be recovered w/o(!) knowing the config/class.
- Add native (keras/torch) model files resulting from a `save` operation on this native model.
```python
# Reset policy to its original state and compare.
policy.set_state(state1)
state3 = policy.get_state()
# Make sure everything is the same.
check(state1, state3)
check(state1["_exploration_state"], state3["_exploration_state"])
```
Expanded test a little.
Signed-off-by: sven1977 <[email protected]>
```diff
@@ -2912,6 +2912,16 @@ py_test(
     tags = ["team:rllib", "exclusive", "examples", "examples_E", "no_main"],
     size = "medium",
     srcs = ["examples/export/onnx_tf.py"],
     args = ["--framework=tf"],
 )
```
Added tf2 example for exporting to ONNX format.
rllib/evaluation/rollout_worker.py (Outdated)

```diff
@@ -1457,7 +1457,7 @@ def get_filters(self, flush_after: bool = False) -> Dict:
         return return_filters

     @DeveloperAPI
-    def save(self) -> bytes:
+    def get_state(self) -> bytes:
```
Renamed for clarity:
- `get_state()` -> always returns a dict (see `Policy` or `Algorithm`).
- `save()` -> saves to disk and returns the path.
rllib/evaluation/rollout_worker.py (Outdated)

```diff
@@ -1917,6 +1918,14 @@ def export_policy_checkpoint(
     def foreach_trainable_policy(self, func, **kwargs):
         return self.foreach_policy_to_train(func, **kwargs)

     @Deprecated(new="RolloutWorker.get_state()", error=False)
     def save(self, *args, **kwargs):
```
Renamed: See comments above.
Signed-off-by: sven1977 <[email protected]>
rllib/policy/eager_tf_policy.py (Outdated)

```python
state = super().get_state()

# Add this Policy's spec so it can be retrieved w/o access to the original
```
Add PolicySpec AND - if possible - native keras/torch saved model files to returned state.
`PolicySpec` is already part of the dict that is constructed by the (newly renamed) `RolloutWorker.get_state()`, so I don't think we need to save these again. That will probably allow you to avoid a lot of the changes here.
Btw, just a related note: we should try to use `PolicySpec.serialize()` whenever possible.
done
rllib/policy/eager_tf_policy_v2.py (Outdated)

```python
state = super().get_state()

# Add this Policy's spec so it can be retrieved w/o access to the original
```
Add PolicySpec AND - if possible - native keras/torch saved model files to returned state.
1 sounds good 👍
Ah, for 3, I see that you already have simple validation of policy_id. That's great; it's good enough for me for now.
ok, just have a few last nits left.
thanks for all the hard work, this is looking great.
```python
# Write individual policies to disk, each in their own sub-directory.
for pid, policy_state in policy_states.items():
    # From here on, disallow policy IDs that would not work as directory names.
    validate_policy_id(pid, error=True)
```
Can we please run this validation in `Algorithm.validate_config()` and also `add_policy()`? It's too late if we only do this right before checkpointing.
It was already in `add_policy`. Added it to `check_multi_agent`, which is called from within `Algorithm.validate_config`. Good catch!
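The per-policy write loop under discussion can be sketched in plain Python. The file names, directory layout, and the simplified validator below are assumptions for illustration, not RLlib's actual implementation:

```python
import os
import pickle
import re
import tempfile


def validate_policy_id(pid, error=False):
    # Simplified stand-in: policy IDs must be usable as directory names.
    # The exact character set RLlib allows is an assumption here.
    if not re.match(r"^[A-Za-z0-9_\-\.]+$", pid):
        if error:
            raise ValueError(f"Invalid policy ID: {pid!r}")


def write_policy_states(checkpoint_dir, policy_states):
    """Write each policy's state into its own sub-directory."""
    paths = []
    for pid, policy_state in policy_states.items():
        # Disallow policy IDs that would not work as directory names.
        validate_policy_id(pid, error=True)
        policy_dir = os.path.join(checkpoint_dir, "policies", pid)
        os.makedirs(policy_dir, exist_ok=True)
        # "policy_state.pkl" is an arbitrary name for this sketch.
        state_file = os.path.join(policy_dir, "policy_state.pkl")
        with open(state_file, "wb") as f:
            pickle.dump(policy_state, f)
        paths.append(state_file)
    return paths


with tempfile.TemporaryDirectory() as d:
    paths = write_policy_states(d, {"pol1": {"w": [1, 2]}, "pol2": {"w": [3]}})
print(len(paths))  # → 2
```

Because the policy ID becomes a path component, validating it before any directory is created (rather than between checkpoint writes) is exactly the concern raised in this thread.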
rllib/policy/policy.py (Outdated)

```python
    worker_state["policy_config"], serialized_policy_spec["config"]
)
serialized_policy_spec.update({"config": policy_config})
policy_state.update(
```
Feels like this is not necessary if we have already updated `serialized_policy_spec`:

```python
>>> a = {"b": {"c": 1}}
>>> p = a["b"]
>>> p["c"] = 2
>>> a
{'b': {'c': 2}}
```
True! Fixed.
```python
policies = Policy.from_checkpoint(path_to_checkpoint)
assert "default_policy" in policies

print(algo.train())
```
do you intend to keep the print() here?
Yes, printing train results is pretty common in our tests.
```python
tf1, tf, tfv = try_import_tf()


@PublicAPI(stability="alpha")
def validate_policy_id(policy_id: str, error: bool = False) -> None:
```
can we please write a unit test for this util?
Is this really necessary? This is a single-line, one-regexp-match function!?
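For reference, even a one-liner like this is cheap to test. The sketch below is illustrative only: the regexp is an assumption, not RLlib's actual pattern, and the function shape just mirrors the signature shown in the diff above:

```python
import re

# Illustrative pattern (an assumption): only characters that are safe to use
# in a directory name are allowed in a policy ID.
_VALID_POLICY_ID = re.compile(r"^[A-Za-z0-9_\-\.]+$")


def validate_policy_id(policy_id, error=False):
    """Raise (or just warn) if `policy_id` would not work as a directory name."""
    if not isinstance(policy_id, str) or not _VALID_POLICY_ID.match(policy_id):
        msg = f"Invalid policy ID: {policy_id!r}"
        if error:
            raise ValueError(msg)
        print(f"WARNING: {msg}")


# The kind of unit test being asked for: good IDs pass, bad ones raise.
validate_policy_id("default_policy", error=True)  # ok, no exception
try:
    validate_policy_id("bad/id", error=True)
    raised = False
except ValueError:
    raised = True
assert raised
```

A test like this mostly guards against future regressions in the allowed character set, which matters once IDs become path components.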
rllib/utils/files.py (Outdated)

```python
from typing import Dict


def dir_contents_to_dict(dir: str) -> Dict:
```
Can you add a comment here noting that we only use this util for unit tests right now, and that people should be careful using it in production, since it hasn't really been battle-tested?
```python
@PublicAPI(stability="alpha")
def get_checkpoint_info(checkpoint) -> Dict[str, Any]:
```
Do you think you can start a unit-test file for this module? Write a simple checkpoint, then expect the parsed info to have a certain shape, etc. This will make it a lot more convenient for us to add unit tests in the future.
Removed the whole thing (utils/files.py). It's not needed anymore, not even for our test cases.
will do ...
done
On 2) above:
We don't have the worker yet! :) And the whole point is to avoid having to generate the entire policy_map before we decide to only create a few policies (then we would have to remove some policies again). One of the main purposes of this PR is to NOT have to create all policies that have been stored in an Algo checkpoint, but to select which subset of policies to kick-start the Algo with, given a checkpoint.
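The subset-selection idea described above can be sketched without RLlib. The helper, file names, and directory layout below are hypothetical stand-ins; `policy_ids=None` meaning "load all" is an assumption for this sketch:

```python
import os
import pickle
import tempfile


def load_policy_states(checkpoint_dir, policy_ids=None):
    """Load only the requested policies' states from a checkpoint dir.

    policy_ids=None loads all policies; otherwise only the given subset is
    read from disk, so states for unwanted policies are never materialized.
    """
    policies_dir = os.path.join(checkpoint_dir, "policies")
    states = {}
    for pid in os.listdir(policies_dir):
        if policy_ids is not None and pid not in policy_ids:
            continue
        with open(os.path.join(policies_dir, pid, "state.pkl"), "rb") as f:
            states[pid] = pickle.load(f)
    return states


# Build a fake checkpoint with two policies, then restore only one of them.
with tempfile.TemporaryDirectory() as d:
    for pid, state in {"pol1": {"w": 1}, "pol2": {"w": 2}}.items():
        pdir = os.path.join(d, "policies", pid)
        os.makedirs(pdir)
        with open(os.path.join(pdir, "state.pkl"), "wb") as f:
            pickle.dump(state, f)
    subset = load_policy_states(d, policy_ids=["pol1"])
print(sorted(subset))  # → ['pol1']
```

Filtering at load time, rather than building the full policy_map and then deleting entries, is the point made in the comment above.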
Signed-off-by: sven1977 <[email protected]>
Ah, yeah, my bad, I got confused about where this is used.
ok man, thanks for your patience.
let's merge :)
Signed-off-by: sven1977 <[email protected]>
Signed-off-by: sven1977 <[email protected]>
…l_export_overhaul
Signed-off-by: sven1977 <[email protected]>
Signed-off-by: sven1977 <[email protected]>
Signed-off-by: sven1977 <[email protected]>
Signed-off-by: sven1977 <[email protected]>
Signed-off-by: sven1977 <[email protected]>
Signed-off-by: sven1977 <[email protected]>
Signed-off-by: sven1977 <[email protected]>
… export (in native formats). (ray-project#28166)" This reverts commit 23b3a59.
…(in native formats). (ray-project#28166) Signed-off-by: Weichen Xu <[email protected]>
IMPORTANT: A documentation PR (#28812) is in progress and will be merged right after this one.

This PR provides a complete overhaul of RLlib's checkpointing mechanisms (for Algorithms and individual Policies). It also introduces the option to save the NN models of a Policy as native keras/PyTorch files within the Policy checkpoints.

Details:

On Algorithm checkpoints:
- `my_algo.restore([some AIR checkpoint])` works, as well as `Algorithm.restore([some path to a checkpoint dir])`. The checkpoint directory structure will change from: ... to: ...
- New: `from_checkpoint()` and `from_state()`, both of which return new Algorithm objects, given a checkpoint dir or object, or a state dict, respectively. I.e.: `my_new_algo = Algorithm.from_checkpoint([path to AIR checkpoint OR AIR checkpoint obj])`. No original config or other information is needed other than the checkpoint.

On Policy checkpoints:
- `Policy.export_checkpoint()` produces an AIR Checkpoint directory with all the policy's state in it.
- New: `from_checkpoint()` and `from_state()`, both of which return new Policy objects, given a Policy checkpoint dir or object, or a Policy state, respectively.

On native keras/PyTorch models being part of a Policy checkpoint (optional):
- `config.checkpointing(checkpoints_contain_native_model_files=True)` makes Policies also try to write their NN model as a native keras/torch saved model into the given checkpoint directory (under sub-dir "model"). This may still fail (gracefully) in some cases, e.g. for certain TfModelV2 setups where the keras `self.base_model` (of the TfModelV2) cannot be discovered easily. This problem will be fully solved by the ongoing RLModule/RLTrainer API efforts.

Additional cleanup:
Why are these changes needed?
Related issue number
Checks
- I've signed off every commit (by using the -s flag, i.e., `git commit -s`) in this PR.
- I've run `scripts/format.sh` to lint the changes in this PR.