Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RLlib] Properly serialize and restore StateBufferConnector states for policy stashing #31372

Merged
merged 15 commits into from
Jan 5, 2023
41 changes: 38 additions & 3 deletions rllib/connectors/agent/state_buffer.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,56 @@
from collections import defaultdict
import logging
import pickle
from typing import Any

import numpy as np
from ray.rllib.utils.annotations import override
import tree # dm_tree

from ray.rllib.connectors.connector import (
AgentConnector,
Connector,
ConnectorContext,
)
from ray import cloudpickle
from ray.rllib.connectors.registry import register_connector
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space
from ray.rllib.utils.typing import ActionConnectorDataType, AgentConnectorDataType
from ray.util.annotations import PublicAPI


logger = logging.getLogger(__name__)


@PublicAPI(stability="alpha")
class StateBufferConnector(AgentConnector):
def __init__(self, ctx: ConnectorContext):
def __init__(self, ctx: ConnectorContext, states: Any = None):
super().__init__(ctx)

self._initial_states = ctx.initial_states
self._action_space_struct = get_base_struct_from_space(ctx.action_space)

self._states = defaultdict(lambda: defaultdict(lambda: (None, None, None)))
# TODO(jungong) : we would not need this if policies are never stashed
# during the rollout of a single episode.
if states:
try:
self._states = cloudpickle.loads(states)
except pickle.UnpicklingError:
# StateBufferConnector states are only needed for rare cases
# like stashing then restoring a policy during the rollout of
# a single episode.
# It is ok to ignore the error for most of the cases here.
logger.info(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't get why you should not error out all the time here? When will you ever pass in a states object in that is ok if it's not unpickled?

Copy link
Member Author

@gjoliver gjoliver Jan 4, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

when you recover a policy for serving, we wouldn't need the state buffer state.
actually this comment reminded me, I am letting StateBufferConnector clear state whenever someone switch it into eval mode.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In that case we should not pass in the states at all?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe we can discuss a bit.
I feel like the only case we need to restore this is when policies are stashed and recovered in the middle of an episode.
we don't need this state even for training as long as we keep the policy in cache throughout an episode.
so maybe a better fix is to ping any "in-use" poclies, and prevent them from being stashed.

"Can not restore StateBufferConnector states. This warning can "
"usually be ignore, unless it is from restoring a stashed policy."
)

@override(Connector)
def in_eval(self):
self._states.clear()
super().in_eval()

def reset(self, env_id: str):
# States should not be carried over between episodes.
Expand Down Expand Up @@ -70,11 +98,18 @@ def transform(self, ac_data: AgentConnectorDataType) -> AgentConnectorDataType:
return ac_data

def to_state(self):
return StateBufferConnector.__name__, None
# Note(jungong) : it is ok to use cloudpickle here for stats because:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The reason to use cloudpickle over pickle is simply that you are pickling a data-structure that contains lambda function, right?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I actually think we should always use cloudpickle.
It's not officially guaranteed, but there is a less chance of not being able to restore states saved with a higher version of python if we use cloudpickle.
cloudpickle kinda does it automatically for you, using the right pickle library if the version is low (pickle5 etc).

# 1. self._states may contain arbitary data objects, and will be hard
# to serialize otherwise.
# 2. seriazlized states are only useful if a policy is stashed and
# restored during the rollout of a single episode. So it is ok to
# use cloudpickle for such non-persistent data bits.
states = cloudpickle.dumps(self._states)
return StateBufferConnector.__name__, states

@staticmethod
def from_state(ctx: ConnectorContext, params: Any):
return StateBufferConnector(ctx)
return StateBufferConnector(ctx, params)


register_connector(StateBufferConnector.__name__, StateBufferConnector)
5 changes: 3 additions & 2 deletions rllib/connectors/agent/view_requirement.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,9 @@ def transform(self, ac_data: AgentConnectorDataType) -> AgentConnectorDataType:
"and agent_id({agent_id})"
)

vr = self._view_requirements
assert vr, "ViewRequirements required by ViewRequirementAgentConnector"
assert (
self._view_requirements
), "ViewRequirements required by ViewRequirementAgentConnector"

# Note(jungong) : we need to keep the entire input dict here.
# A column may be used by postprocessing (GAE) even if its
Expand Down
2 changes: 1 addition & 1 deletion rllib/evaluation/collectors/agent_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,7 @@ def build_for_inference(self) -> SampleBatch:
return batch

# TODO: @kouorsh we don't really need view_requirements anymore since it's already
# and attribute of the class
# an attribute of the class
def build_for_training(
self, view_requirements: ViewRequirementsDict
) -> SampleBatch:
Expand Down
4 changes: 1 addition & 3 deletions rllib/evaluation/collectors/simple_list_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,7 @@ def build(self):

class _PolicyCollectorGroup:
def __init__(self, policy_map):
self.policy_collectors = {
pid: _PolicyCollector(policy) for pid, policy in policy_map.items()
}
self.policy_collectors = {}
# Total env-steps (1 env-step=up to N agents stepped).
self.env_steps = 0
# Total agent steps (1 agent-step=1 individual agent (out of N)
Expand Down
2 changes: 1 addition & 1 deletion rllib/evaluation/env_runner_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -1124,7 +1124,7 @@ def _process_policy_eval_results(

# Notify agent connectors with this new policy output.
# Necessary for state buffering agent connectors, for example.
ac_data: AgentConnectorDataType = ActionConnectorDataType(
ac_data: ActionConnectorDataType = ActionConnectorDataType(
env_id,
agent_id,
input_dict,
Expand Down
8 changes: 4 additions & 4 deletions rllib/policy/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -970,14 +970,14 @@ def restore_connectors(self, state: PolicyState):
self.agent_connectors = restore_connectors_for_policy(
self, connector_configs["agent"]
)
logger.info("restoring agent connectors:")
logger.info(self.agent_connectors.__str__(indentation=4))
logger.debug("restoring agent connectors:")
logger.debug(self.agent_connectors.__str__(indentation=4))
if "action" in connector_configs:
self.action_connectors = restore_connectors_for_policy(
self, connector_configs["action"]
)
logger.info("restoring action connectors:")
logger.info(self.action_connectors.__str__(indentation=4))
logger.debug("restoring action connectors:")
logger.debug(self.action_connectors.__str__(indentation=4))

@DeveloperAPI
@OverrideToImplementCustomLogic_CallToSuperRecommended
Expand Down
5 changes: 4 additions & 1 deletion rllib/policy/policy_map.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from collections import deque
import threading
from typing import Dict, Set
import logging

import ray
from ray.rllib.policy.policy import Policy
Expand All @@ -12,6 +13,7 @@
from ray.util.annotations import PublicAPI

tf1, tf, tfv = try_import_tf()
logger = logging.getLogger(__name__)


@PublicAPI(stability="beta")
Expand Down Expand Up @@ -129,9 +131,10 @@ def __getitem__(self, item: PolicyID):
# -> Load new policy's state into the one that just got removed from the cache.
# This way, we save the costly re-creation step.
if policy is not None and self.policy_states_are_swappable:
logger.debug(f"restoring policy: {item}")
policy.set_state(policy_state)
#
else:
logger.debug(f"creating new policy: {item}")
policy = Policy.from_state(policy_state)

self.cache[item] = policy
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@
register_env("multi_cartpole", lambda _: MultiAgentCartPole({"num_agents": 2}))

# Number of policies overall in the PolicyMap.
num_policies = 100
num_policies = 20
# Number of those policies that should be trained. These are a subset of `num_policies`.
num_trainable = 20
num_trainable = 10

num_envs_per_worker = 5
num_envs_per_worker = 2

# Define the config as an APPOConfig object.
config = (
Expand Down
4 changes: 3 additions & 1 deletion rllib/utils/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,9 @@ def check(x, y, decimals=5, atol=None, rtol=None, false=False):
else:
assert x == y, f"ERROR: x ({x}) is not the same as y ({y})!"
# String/byte comparisons.
elif hasattr(x, "dtype") and (x.dtype == object or str(x.dtype).startswith("<U")):
elif (
hasattr(x, "dtype") and (x.dtype == object or str(x.dtype).startswith("<U"))
) or isinstance(x, bytes):
try:
np.testing.assert_array_equal(x, y)
if false is True:
Expand Down