
[RLlib] Algorithm/Policy checkpoint overhaul and Policy Model export (in native formats). #28166

Merged (105 commits) into ray-project:master on Oct 6, 2022

Conversation

sven1977 (Contributor) commented Aug 30, 2022

IMPORTANT: A Documentation PR (#28812) is in progress and will be merged right after this one.

This PR provides a complete overhaul of RLlib's checkpointing mechanisms (for Algorithms and individual Policies).
It also introduces the option to save a policy's NN models as native Keras/PyTorch files within the Policy checkpoints.

Details:
On Algorithm checkpoints:

  • All Algorithm checkpoints now use the AIR Checkpoint mechanism. I.e. my_algo.restore([some AIR checkpoint]) works as well as Algorithm.restore([some path to a checkpoint dir]). The checkpoint directory structure will change from:
.
..
checkpoint-[some iter num]

to:

.
..
policies/
    policy_1/
        policy_state.pkl
    policy_2/
        policy_state.pkl
checkpoint_version.txt
state.pkl
  • Algorithm checkpoints now have a version (e.g. "v0", "v1") stored in the checkpoint dir under "checkpoint_version.txt". This will help keep checkpoint handling fully backward compatible from Ray 2.0 on. Test cases introduced in this PR confirm that this is, and remains, the case.
  • Algorithm checkpoints now contain a sub-directory ("policies") which has further sub-directories (named after the policies' IDs) that contain the individual policy checkpoints (see below). This allows for easier decomposition and re-assembly of Policies within an Algorithm checkpoint (e.g. restore an Algorithm from a checkpoint, but only with policies A and B, instead of the original A, B, and C, or restoring a Policy instance individually).
  • Algorithm gets two new static utilities: from_checkpoint() and from_state(), both of which return new Algorithm objects, given a checkpoint (dir or object) or a state dict, respectively. I.e.: my_new_algo = Algorithm.from_checkpoint([path to AIR checkpoint OR AIR checkpoint obj]). No original config or other information is needed beyond the checkpoint (a usage sketch follows this list).
  • Test cases have been added to keep checkpoint backward compatibility and to test these new utilities and dir structures.
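
A minimal round-trip sketch of the new mechanism (the algorithm, env, and variable names are illustrative, not taken from this PR):

```python
from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.algorithms.ppo import PPOConfig

# Build and train any algorithm (PPO on CartPole is illustrative only).
algo = PPOConfig().environment(env="CartPole-v1").build()
algo.train()

# save() writes an AIR Checkpoint directory (layout as shown above) and
# returns its path.
checkpoint_dir = algo.save()
algo.stop()

# Re-create a full Algorithm from nothing but the checkpoint.
restored = Algorithm.from_checkpoint(checkpoint_dir)
```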

On Policy Checkpoints:

  • Policy checkpoints now use the AIR Checkpoint mechanism. I.e. Policy.export_checkpoint() produces an AIR Checkpoint directory with all the policy's state in it.
  • Policy gets two new static utilities: from_checkpoint() and from_state(), both of which return new Policy objects, given a Policy checkpoint (dir or object) or a Policy state, respectively (see the sketch below).
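
A sketch of the Policy-level utilities (reusing the checkpoint_dir from the Algorithm sketch above; the dict-return behavior for Algorithm checkpoints is per this PR's tests):

```python
from ray.rllib.policy.policy import Policy

# On an Algorithm checkpoint, from_checkpoint() returns a dict mapping
# policy IDs to Policy objects; a single-Policy checkpoint yields one Policy.
policies = Policy.from_checkpoint(checkpoint_dir)
policy = policies["default_policy"]

# In-memory round trip through the state dict.
state = policy.get_state()
clone = Policy.from_state(state)
```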

On native keras/PyTorch models being part of a Policy checkpoint (optional):

  • A new config option: config.checkpointing(checkpoints_contain_native_model_files=True) makes Policies also try to write their NN model as a native keras/torch saved model into the given checkpoint directory (under sub-dir "model"). This may still fail (gracefully) in some cases, e.g. for certain TfModelV2 classes whose keras self.base_model cannot be discovered easily. This problem will be fully solved by the ongoing RLModule/RLTrainer API efforts. (See the sketch below.)
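
A hedged sketch, assuming a TF2/Keras setup (the checkpointing() option is as named in this PR's description; the exact model sub-path follows the directory layout shown earlier):

```python
import tensorflow as tf
from ray.rllib.algorithms.ppo import PPOConfig

config = (
    PPOConfig()
    .environment(env="CartPole-v1")
    .checkpointing(checkpoints_contain_native_model_files=True)
)
algo = config.build()
path = algo.save()

# The keras SavedModel should then sit under the policy's "model" sub-dir
# (export may fail gracefully for some TfModelV2 classes, per the note above).
model = tf.keras.models.load_model(f"{path}/policies/default_policy/model")
```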

Additional cleanup:

  • Deprecate TF1-style saving of the entire tf graph (incl. loss) when calling Policy.[export_model/export_checkpoint]. This was replaced in favor of saving the model directly in native tf/torch formats (see the sketch below).
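
As a sketch of the replacement path (method per RLlib's public Algorithm API; the export directory and policy ID are illustrative):

```python
# Exports only the policy's NN model in its native format
# (keras SavedModel or torch, depending on the configured framework).
algo.export_policy_model("/tmp/exported_model", policy_id="default_policy")
```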

Why are these changes needed?

Related issue number

Checks

  • I've signed off every commit (by using the -s flag, i.e., git commit -s) in this PR.
  • I've run scripts/format.sh to lint the changes in this PR.
  • I've included any doc changes needed for https://docs.ray.io/en/master/.
  • I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
  • Testing Strategy
    • Unit tests
    • Release tests
    • This PR is not tested :(

@@ -1457,7 +1457,7 @@ def get_filters(self, flush_after: bool = False) -> Dict:
         return return_filters

     @DeveloperAPI
-    def save(self) -> bytes:
+    def get_state(self) -> bytes:
sven1977 (Contributor Author):

Unify naming:

  • get_state() -> returns a dict that completely describes the current state of the object
  • set_state(state) -> restores the object to a given state (e.g. one previously returned by get_state())
  • save() -> creates a checkpoint on disk and returns the path where it was saved
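
As a sketch, for any hypothetical object `obj` following this convention:

```python
state = obj.get_state()        # dict that fully describes the object
obj.set_state(state)           # restore from such a dict
checkpoint_dir = obj.save()    # write a checkpoint to disk, return its path
```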

@@ -1486,7 +1487,7 @@ def save(self) -> bytes:
         )

     @DeveloperAPI
-    def restore(self, objs: bytes) -> None:
+    def set_state(self, objs: bytes) -> None:
sven1977 (Contributor Author):

Same as above

@@ -1,59 +1,75 @@
import argparse
sven1977 (Contributor Author):

Cleaned up our ONNX examples to better fit our typical example scripts format.

        state = super().get_state()
        # Add this Policy's spec so it can be retrieved w/o access to the original
        # code.
        state["policy_spec"] = PolicySpec(
sven1977 (Contributor Author):

New:

  • Add the full PolicySpec to the state, such that a policy can be recovered w/o(!) knowing the config/class.
  • Add the native (keras/torch) model files that result from saving this policy's model (see the sketch below).
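
In effect (the "policy_spec" key is taken from the diff above; from_state() is per this PR; other names are illustrative):

```python
state = policy.get_state()
spec = state["policy_spec"]         # full PolicySpec: class, spaces, config
rebuilt = Policy.from_state(state)  # no access to the original code needed
```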

# Reset policy to its original state and compare.
policy.set_state(state1)
state3 = policy.get_state()
# Make sure everything is the same.
check(state1, state3)
check(state1["_exploration_state"], state3["_exploration_state"])
sven1977 (Contributor Author):

Expanded test a little.

@@ -2912,6 +2912,16 @@ py_test(
     tags = ["team:rllib", "exclusive", "examples", "examples_E", "no_main"],
     size = "medium",
     srcs = ["examples/export/onnx_tf.py"],
     args = ["--framework=tf"],
 )

sven1977 (Contributor Author):

Added tf2 example for exporting to ONNX format.
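
A hedged sketch of the export path such an example exercises (Policy.export_model() accepts an onnx opset argument; the opset value and output dir are illustrative):

```python
# With onnx set, export_model() writes an ONNX model instead of the
# native-format saved model.
policy = algo.get_policy()
policy.export_model("/tmp/onnx_export", onnx=11)
```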

@@ -1457,7 +1457,7 @@ def get_filters(self, flush_after: bool = False) -> Dict:
         return return_filters

     @DeveloperAPI
-    def save(self) -> bytes:
+    def get_state(self) -> bytes:
sven1977 (Contributor Author):

Renamed for clarity:

  • get_state() -> always returns a dict (see Policy or Algorithm)
  • save() -> saves to disk and returns the path

@@ -1917,6 +1918,14 @@ def export_policy_checkpoint(
     def foreach_trainable_policy(self, func, **kwargs):
         return self.foreach_policy_to_train(func, **kwargs)

+    @Deprecated(new="RolloutWorker.get_state()", error=False)
+    def save(self, *args, **kwargs):
sven1977 (Contributor Author):

Renamed: See comments above.

        state = super().get_state()

        # Add this Policy's spec so it can be retrieved w/o access to the original
sven1977 (Contributor Author):

Add PolicySpec AND - if possible - native keras/torch saved model files to returned state.

Member:

PolicySpec is already part of the dict that is constructed by the (newly renamed) RolloutWorker.get_state(), so I don't think we need to save these again? That would probably allow you to avoid a lot of the changes here.
Btw, just a related note: we should try to use PolicySpec.serialize() whenever possible.
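
A minimal sketch of that serialize round trip (the deserialize() counterpart and import path are assumed from RLlib's PolicySpec API, not from this thread):

```python
from ray.rllib.policy.policy import PolicySpec

# Round-trip a spec through its plain-dict form (pickle/JSON friendly).
serialized = spec.serialize()
restored_spec = PolicySpec.deserialize(serialized)
```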

sven1977 (Contributor Author):

done

gjoliver (Member) commented Oct 5, 2022

> Thanks for another round of review @gjoliver! Here are some of my responses to the latest questions. The rest has been answered above (or resolved):
>
> • Algo registry: Yes, we should do this. I wasn't against it, just not entirely sure how bad it would be. But I think you are right: algos change a lot on our end and this would mess up the stability. Let's do this in a separate PR, though; I don't want to add any more complexity to this one.
> • Fixed the policy_ids from filters. Instead, we are now storing a separate list of policy_ids in the worker state and using that list. This is cleaner and less confusing, and will also allow us to simply deprecate the "filters" key in the worker state in the future.
> • PolicyIDs now must be file/dir-name safe. There is no way around this, in my opinion (otherwise, we cannot guarantee a unique(!), non-clashing mapping from policy ID to sub-dir name). It still works with older checkpoints that don't have directories and whose policy IDs contain now-illegal characters, but one would get a warning either way.

1. Sounds good 👍
2. Do we really need the additional policy_ids attribute? Why can't we simply use list(rollout_workers.policy_map.keys())?
3. Agree with you that things will get hairy if we start sanitizing policy names. I feel like we need to put a big TODO somewhere, so that we start validating these policy names using libraries like https://pypi.org/project/pathvalidate/ if we want to enforce the rule. Otherwise: (1) you may train for 30 minutes before hitting this file-name error; (2) "aa/bb/cc" is actually a valid name, so you would just create the files a few levels deeper, which may break the restore logic. An alternative is to always name the policy dirs p1, p2, ..., then put a little text file in each policy folder with the actual name. In terms of scripting needs, this is also quite usable.

gjoliver (Member) commented Oct 5, 2022

Ah, for 3, I see that you already have simple validation of policy_id. That's great; this is good enough for me for now.

gjoliver (Member) left a review comment:

ok, just have a few last nits left.
thanks for all the hard work, this is looking great.

        # Write individual policies to disk, each in their own sub-directory.
        for pid, policy_state in policy_states.items():
            # From here on, disallow policy IDs that would not work as directory names.
            validate_policy_id(pid, error=True)
Member:

can we please run this validation in Algorithm.validate_config() and also in add_policy()?
it's too late if we only do it right before checkpointing.

sven1977 (Contributor Author):

It was already in add_policy.
Added it to check_multi_agent, which is called from within Algorithm.validate_config. Good catch!

worker_state["policy_config"], serialized_policy_spec["config"]
)
serialized_policy_spec.update({"config": policy_config})
policy_state.update(
Member:

Feels like this is not necessary if we have already updated serialized_policy_spec:

>>> a = {"b": {"c": 1}}
>>> p = a["b"]
>>> p["c"] = 2
>>> a
{'b': {'c': 2}}

sven1977 (Contributor Author):

True! Fixed.

policies = Policy.from_checkpoint(path_to_checkpoint)
assert "default_policy" in policies

print(algo.train())
Member:

do you intend to keep the print() here?

sven1977 (Contributor Author):

yes, printing train results is pretty common in our tests.

tf1, tf, tfv = try_import_tf()


@PublicAPI(stability="alpha")
def validate_policy_id(policy_id: str, error: bool = False) -> None:
Member:

can we please write a unit test for this util?

sven1977 (Contributor Author):

Is this really necessary? It's a single line, one regexp match!
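
For reference, a minimal sketch of what such a test could look like (the module path is assumed from the surrounding diff, and the raised exception type is an assumption, not from this thread):

```python
import pytest

from ray.rllib.utils.policy import validate_policy_id  # module path assumed


def test_validate_policy_id():
    # A dir-name-safe ID should pass silently.
    validate_policy_id("my_policy-1", error=True)
    # Slashes are not dir-name safe; exception type assumed here.
    with pytest.raises(ValueError):
        validate_policy_id("aa/bb/cc", error=True)
```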

from typing import Dict


def dir_contents_to_dict(dir: str) -> Dict:
Member:

can you add a comment here noting that we only use this util for unit tests right now, and that it should be used carefully in production, since it hasn't really been battle-tested?



@PublicAPI(stability="alpha")
def get_checkpoint_info(checkpoint) -> Dict[str, Any]:
Member:

do you think you can start a unit test file for this file?
write a simple checkpoint, then expect the parsed info to have a certain shape, etc.
that way it will be a lot more convenient for us to add unit tests in the future.

sven1977 (Contributor Author):

Removed the whole thing (utils/files.py). It's not needed anymore, not even by our test cases.

sven1977 (Contributor Author):

will do ...

sven1977 (Contributor Author):

done

sven1977 (Contributor Author) commented Oct 5, 2022

On 2) above:

> 2. Do we really need the additional policy_ids attribute? Why can't we simply use list(rollout_workers.policy_map.keys())?

We don't have the worker yet! :) And the whole point is to avoid having to generate the entire policy_map before we decide to create only a few policies (otherwise, we would have to remove some policies again). One of the main purposes of this PR is to NOT have to create all policies stored in an Algo checkpoint, but to select which subset of policies to kick-start the Algo with, given a checkpoint (a sketch follows below).
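
A hedged sketch of this selective-restore flow (the policy_ids and policy_mapping_fn keyword names follow the discussion in this thread; the policy IDs and mapping are illustrative):

```python
# Restore only a subset of the policies stored in the checkpoint.
algo = Algorithm.from_checkpoint(
    checkpoint_dir,
    policy_ids=["policy_A", "policy_B"],  # leave policy_C out entirely
    policy_mapping_fn=lambda agent_id, episode, worker, **kw: "policy_A",
)
```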

gjoliver (Member) commented Oct 5, 2022

> On 2) above:
>
> > 2. Do we really need the additional policy_ids attribute? Why can't we simply use list(rollout_workers.policy_map.keys())?
>
> We don't have the worker yet! :) And the whole point is to avoid having to generate the entire policy_map before we decide to create only a few policies. One of the main purposes of this PR is to NOT have to create all policies stored in an Algo checkpoint, but to select which subset of policies to kick-start the Algo with, given a checkpoint.

ah, yeah, my bad, I got confused about where this is used.

gjoliver (Member) left a review comment:

ok man, thanks for your patience.
let's merge :)

sven1977 merged commit 23b3a59 into ray-project:master on Oct 6, 2022
gjoliver pushed a commit to gjoliver/ray that referenced this pull request Oct 10, 2022
WeichenXu123 pushed a commit to WeichenXu123/ray that referenced this pull request Dec 19, 2022
Labels: tests-ok (the tagger certifies test failures are unrelated and assumes personal liability)