[RLlib] Properly serialize and restore StateBufferConnector states for policy stashing #31372
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,28 +1,56 @@ | ||
from collections import defaultdict | ||
import logging | ||
import pickle | ||
from typing import Any | ||
|
||
import numpy as np | ||
from ray.rllib.utils.annotations import override | ||
import tree # dm_tree | ||
|
||
from ray.rllib.connectors.connector import ( | ||
AgentConnector, | ||
Connector, | ||
ConnectorContext, | ||
) | ||
from ray import cloudpickle | ||
from ray.rllib.connectors.registry import register_connector | ||
from ray.rllib.policy.sample_batch import SampleBatch | ||
from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space | ||
from ray.rllib.utils.typing import ActionConnectorDataType, AgentConnectorDataType | ||
from ray.util.annotations import PublicAPI | ||
|
||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
@PublicAPI(stability="alpha") | ||
class StateBufferConnector(AgentConnector): | ||
- def __init__(self, ctx: ConnectorContext): | ||
def __init__(self, ctx: ConnectorContext, states: Any = None): | ||
super().__init__(ctx) | ||
|
||
self._initial_states = ctx.initial_states | ||
self._action_space_struct = get_base_struct_from_space(ctx.action_space) | ||
|
||
self._states = defaultdict(lambda: defaultdict(lambda: (None, None, None))) | ||
# TODO(jungong) : we would not need this if policies are never stashed | ||
# during the rollout of a single episode. | ||
if states: | ||
try: | ||
self._states = cloudpickle.loads(states) | ||
except pickle.UnpicklingError: | ||
# StateBufferConnector states are only needed for rare cases | ||
# like stashing then restoring a policy during the rollout of | ||
# a single episode. | ||
# It is ok to ignore the error for most of the cases here. | ||
logger.info( | ||
"Cannot restore StateBufferConnector states. This warning can " | ||
"usually be ignored, unless it is from restoring a stashed policy." | ||
) | ||
|
||
@override(Connector) | ||
def in_eval(self): | ||
self._states.clear() | ||
super().in_eval() | ||
|
||
def reset(self, env_id: str): | ||
# States should not be carried over between episodes. | ||
|
@@ -70,11 +98,18 @@ def transform(self, ac_data: AgentConnectorDataType) -> AgentConnectorDataType: | |
return ac_data | ||
|
||
def to_state(self): | ||
- return StateBufferConnector.__name__, None | ||
# Note(jungong) : it is ok to use cloudpickle here for states because: | ||
[Review comment] The reason to use cloudpickle over pickle is simply that you are pickling a data structure that contains a lambda function, right?
[Review comment] I actually think we should always use cloudpickle. |
||
# 1. self._states may contain arbitrary data objects, and will be hard | ||
# to serialize otherwise. | ||
# 2. serialized states are only useful if a policy is stashed and | ||
# restored during the rollout of a single episode. So it is ok to | ||
# use cloudpickle for such non-persistent data bits. | ||
states = cloudpickle.dumps(self._states) | ||
return StateBufferConnector.__name__, states | ||
|
||
@staticmethod | ||
def from_state(ctx: ConnectorContext, params: Any): | ||
- return StateBufferConnector(ctx) | ||
return StateBufferConnector(ctx, params) | ||
|
||
|
||
register_connector(StateBufferConnector.__name__, StateBufferConnector) |
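
The review question about lambdas can be checked with a small standalone sketch. It uses only the standard library, not RLlib: the nested `defaultdict` has the same shape as `self._states` in the diff, and the stdlib `pickle` module refuses to serialize it because the `default_factory` lambdas cannot be pickled by reference. `cloudpickle`, by contrast, serializes such functions by value, which is why the diff can dump `self._states` directly. The helpers `make_states`, `to_plain`, and `from_plain` are hypothetical names introduced for illustration; they show one possible stdlib-only alternative (copying into plain dicts), not what the PR does.

```python
import pickle
from collections import defaultdict


def make_states():
    # Same shape as self._states in the diff: env_id -> agent_id -> 3-tuple.
    return defaultdict(lambda: defaultdict(lambda: (None, None, None)))


def to_plain(states):
    # Hypothetical helper: drop the lambda factories by copying into plain dicts.
    return {env_id: dict(agents) for env_id, agents in states.items()}


def from_plain(plain):
    # Hypothetical helper: rebuild the nested defaultdict from plain dicts.
    states = make_states()
    for env_id, agents in plain.items():
        states[env_id].update(agents)
    return states


states = make_states()
states["env_0"]["agent_0"] = ("action", [0.0, 1.0], {"info": 1})

try:
    pickle.dumps(states)
    stdlib_pickle_ok = True
except (pickle.PicklingError, AttributeError):
    # The default_factory lambda cannot be pickled by the stdlib pickler;
    # the exact exception type depends on the Python version.
    stdlib_pickle_ok = False

# Round trip through plain dicts, which the stdlib pickler handles.
restored = from_plain(pickle.loads(pickle.dumps(to_plain(states))))
```

After the round trip, `restored` still hands out the `(None, None, None)` default for unseen env and agent ids, because `from_plain` rebuilds the factories instead of serializing them.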
I don't get why you should not error out all the time here. When would you ever pass in a states object for which it is OK if it cannot be unpickled?

When you recover a policy for serving, we wouldn't need the state buffer state.
Actually, this comment reminded me: I am letting StateBufferConnector clear its state whenever someone switches it into eval mode.

In that case, should we not pass in the states at all?

Maybe we can discuss a bit.
I feel like the only case where we need to restore this is when policies are stashed and recovered in the middle of an episode.
We don't need this state even for training, as long as we keep the policy in cache throughout an episode.
So maybe a better fix is to pin any "in-use" policies and prevent them from being stashed.
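
The pinning idea from the last comment could look roughly like the following. This is a minimal sketch under stated assumptions, not RLlib's policy cache: `PinnedLRUCache`, `pin`, `unpin`, and the `stashed` list are all hypothetical names, and "stashing" is modeled as simple eviction from an in-memory LRU structure.

```python
from collections import OrderedDict


class PinnedLRUCache:
    """Hypothetical LRU cache that never evicts (stashes) pinned keys."""

    def __init__(self, capacity):
        self.capacity = capacity
        self._items = OrderedDict()  # least recently used first
        self._pinned = set()
        self.stashed = []  # evicted keys, recorded only for illustration

    def put(self, key, value):
        self._items[key] = value
        self._items.move_to_end(key)
        self._evict_if_needed()

    def get(self, key):
        # Raises KeyError if the key was never added or has been stashed.
        self._items.move_to_end(key)
        return self._items[key]

    def pin(self, key):
        # Mark a policy as "in use", e.g. at the start of an episode.
        self._pinned.add(key)

    def unpin(self, key):
        # Release the policy, e.g. at the end of an episode.
        self._pinned.discard(key)
        self._evict_if_needed()

    def _evict_if_needed(self):
        while len(self._items) > self.capacity:
            victim = next((k for k in self._items if k not in self._pinned), None)
            if victim is None:
                break  # everything is pinned; tolerate temporary overflow
            del self._items[victim]
            self.stashed.append(victim)


cache = PinnedLRUCache(capacity=2)
cache.put("policy_a", "A")
cache.pin("policy_a")          # policy_a is mid-episode
cache.put("policy_b", "B")
cache.put("policy_c", "C")     # over capacity: policy_b is stashed, not policy_a
stashed_while_pinned = list(cache.stashed)
cache.unpin("policy_a")        # episode finished
cache.put("policy_d", "D")     # now policy_a is the least recently used entry
```

With this design a policy that is pinned for the duration of an episode is never stashed mid-episode, so its state buffer would not need to be serialized at all. The trade-off is that the cache can temporarily exceed its capacity when every entry is pinned.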