Skip to content

Commit

Permalink
[RLlib] Properly serialize and restore StateBufferConnector states fo…
Browse files Browse the repository at this point in the history
…r policy stashing (#31372)

Signed-off-by: Artur Niederfahrenhorst <[email protected]>
  • Loading branch information
Jun Gong authored Jan 5, 2023
1 parent 5a6b234 commit fba15f6
Show file tree
Hide file tree
Showing 9 changed files with 58 additions and 19 deletions.
41 changes: 38 additions & 3 deletions rllib/connectors/agent/state_buffer.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,56 @@
from collections import defaultdict
import logging
import pickle
from typing import Any

import numpy as np
from ray.rllib.utils.annotations import override
import tree # dm_tree

from ray.rllib.connectors.connector import (
AgentConnector,
Connector,
ConnectorContext,
)
from ray import cloudpickle
from ray.rllib.connectors.registry import register_connector
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space
from ray.rllib.utils.typing import ActionConnectorDataType, AgentConnectorDataType
from ray.util.annotations import PublicAPI


logger = logging.getLogger(__name__)


@PublicAPI(stability="alpha")
class StateBufferConnector(AgentConnector):
def __init__(self, ctx: ConnectorContext):
def __init__(self, ctx: ConnectorContext, states: Any = None):
super().__init__(ctx)

self._initial_states = ctx.initial_states
self._action_space_struct = get_base_struct_from_space(ctx.action_space)

self._states = defaultdict(lambda: defaultdict(lambda: (None, None, None)))
# TODO(jungong) : we would not need this if policies are never stashed
# during the rollout of a single episode.
if states:
try:
self._states = cloudpickle.loads(states)
except pickle.UnpicklingError:
# StateBufferConnector states are only needed for rare cases
# like stashing then restoring a policy during the rollout of
# a single episode.
# It is ok to ignore the error for most of the cases here.
logger.info(
"Can not restore StateBufferConnector states. This warning can "
"usually be ignore, unless it is from restoring a stashed policy."
)

@override(Connector)
def in_eval(self):
self._states.clear()
super().in_eval()

def reset(self, env_id: str):
# States should not be carried over between episodes.
Expand Down Expand Up @@ -70,11 +98,18 @@ def transform(self, ac_data: AgentConnectorDataType) -> AgentConnectorDataType:
return ac_data

def to_state(self):
return StateBufferConnector.__name__, None
# Note(jungong) : it is ok to use cloudpickle here for stats because:
# 1. self._states may contain arbitary data objects, and will be hard
# to serialize otherwise.
# 2. seriazlized states are only useful if a policy is stashed and
# restored during the rollout of a single episode. So it is ok to
# use cloudpickle for such non-persistent data bits.
states = cloudpickle.dumps(self._states)
return StateBufferConnector.__name__, states

@staticmethod
def from_state(ctx: ConnectorContext, params: Any):
return StateBufferConnector(ctx)
return StateBufferConnector(ctx, params)


register_connector(StateBufferConnector.__name__, StateBufferConnector)
5 changes: 3 additions & 2 deletions rllib/connectors/agent/view_requirement.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,9 @@ def transform(self, ac_data: AgentConnectorDataType) -> AgentConnectorDataType:
"and agent_id({agent_id})"
)

vr = self._view_requirements
assert vr, "ViewRequirements required by ViewRequirementAgentConnector"
assert (
self._view_requirements
), "ViewRequirements required by ViewRequirementAgentConnector"

# Note(jungong) : we need to keep the entire input dict here.
# A column may be used by postprocessing (GAE) even if its
Expand Down
2 changes: 1 addition & 1 deletion rllib/evaluation/collectors/agent_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,7 @@ def build_for_inference(self) -> SampleBatch:
return batch

# TODO: @kouorsh we don't really need view_requirements anymore since it's already
# and attribute of the class
# an attribute of the class
def build_for_training(
self, view_requirements: ViewRequirementsDict
) -> SampleBatch:
Expand Down
4 changes: 1 addition & 3 deletions rllib/evaluation/collectors/simple_list_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,7 @@ def build(self):

class _PolicyCollectorGroup:
def __init__(self, policy_map):
self.policy_collectors = {
pid: _PolicyCollector(policy) for pid, policy in policy_map.items()
}
self.policy_collectors = {}
# Total env-steps (1 env-step=up to N agents stepped).
self.env_steps = 0
# Total agent steps (1 agent-step=1 individual agent (out of N)
Expand Down
2 changes: 1 addition & 1 deletion rllib/evaluation/env_runner_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -1124,7 +1124,7 @@ def _process_policy_eval_results(

# Notify agent connectors with this new policy output.
# Necessary for state buffering agent connectors, for example.
ac_data: AgentConnectorDataType = ActionConnectorDataType(
ac_data: ActionConnectorDataType = ActionConnectorDataType(
env_id,
agent_id,
input_dict,
Expand Down
8 changes: 4 additions & 4 deletions rllib/policy/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -970,14 +970,14 @@ def restore_connectors(self, state: PolicyState):
self.agent_connectors = restore_connectors_for_policy(
self, connector_configs["agent"]
)
logger.info("restoring agent connectors:")
logger.info(self.agent_connectors.__str__(indentation=4))
logger.debug("restoring agent connectors:")
logger.debug(self.agent_connectors.__str__(indentation=4))
if "action" in connector_configs:
self.action_connectors = restore_connectors_for_policy(
self, connector_configs["action"]
)
logger.info("restoring action connectors:")
logger.info(self.action_connectors.__str__(indentation=4))
logger.debug("restoring action connectors:")
logger.debug(self.action_connectors.__str__(indentation=4))

@DeveloperAPI
@OverrideToImplementCustomLogic_CallToSuperRecommended
Expand Down
5 changes: 4 additions & 1 deletion rllib/policy/policy_map.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from collections import deque
import threading
from typing import Dict, Set
import logging

import ray
from ray.rllib.policy.policy import Policy
Expand All @@ -12,6 +13,7 @@
from ray.util.annotations import PublicAPI

tf1, tf, tfv = try_import_tf()
logger = logging.getLogger(__name__)


@PublicAPI(stability="beta")
Expand Down Expand Up @@ -129,9 +131,10 @@ def __getitem__(self, item: PolicyID):
# -> Load new policy's state into the one that just got removed from the cache.
# This way, we save the costly re-creation step.
if policy is not None and self.policy_states_are_swappable:
logger.debug(f"restoring policy: {item}")
policy.set_state(policy_state)
#
else:
logger.debug(f"creating new policy: {item}")
policy = Policy.from_state(policy_state)

self.cache[item] = policy
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@
register_env("multi_cartpole", lambda _: MultiAgentCartPole({"num_agents": 2}))

# Number of policies overall in the PolicyMap.
num_policies = 100
num_policies = 20
# Number of those policies that should be trained. These are a subset of `num_policies`.
num_trainable = 20
num_trainable = 10

num_envs_per_worker = 5
num_envs_per_worker = 2

# Define the config as an APPOConfig object.
config = (
Expand Down
4 changes: 3 additions & 1 deletion rllib/utils/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,9 @@ def check(x, y, decimals=5, atol=None, rtol=None, false=False):
else:
assert x == y, f"ERROR: x ({x}) is not the same as y ({y})!"
# String/byte comparisons.
elif hasattr(x, "dtype") and (x.dtype == object or str(x.dtype).startswith("<U")):
elif (
hasattr(x, "dtype") and (x.dtype == object or str(x.dtype).startswith("<U"))
) or isinstance(x, bytes):
try:
np.testing.assert_array_equal(x, y)
if false is True:
Expand Down

0 comments on commit fba15f6

Please sign in to comment.