Skip to content

Commit

Permalink
[RLlib] Move connectors creation code to single point in RolloutWorker (
Browse files Browse the repository at this point in the history
#29064)

Signed-off-by: Artur Niederfahrenhorst <[email protected]>
  • Loading branch information
ArturNiederfahrenhorst authored Oct 12, 2022
1 parent f2dacc0 commit 34cd1cb
Showing 1 changed file with 21 additions and 43 deletions.
64 changes: 21 additions & 43 deletions rllib/evaluation/rollout_worker.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from collections import defaultdict
import copy
import logging
import os
Expand Down Expand Up @@ -40,6 +41,7 @@
from ray.rllib.offline import NoopOutput, IOContext, OutputWriter, InputReader
from ray.rllib.policy.policy import Policy, PolicySpec
from ray.rllib.policy.policy_map import PolicyMap
from ray.rllib.utils.filter import NoFilter
from ray.rllib.policy.sample_batch import (
DEFAULT_POLICY_ID,
MultiAgentBatch,
Expand Down Expand Up @@ -622,6 +624,8 @@ def wrap(env):
f"is ignored."
)

self.filters: Dict[PolicyID, Filter] = defaultdict(NoFilter)

self._build_policy_map(
self.policy_dict,
policy_config,
Expand Down Expand Up @@ -649,25 +653,6 @@ def wrap(env):
f"MultiAgentEnv, ActorHandle, or ExternalMultiAgentEnv!"
)

# TODO(jungong) : clean up after non-connector env_runner is fully deprecated.
self.filters: Dict[PolicyID, Filter] = {}
for (policy_id, policy) in self.policy_map.items():
if not policy_config.get("enable_connectors"):
filter_shape = tree.map_structure(
lambda s: (
None
if isinstance(s, (Discrete, MultiDiscrete)) # noqa
else np.array(s.shape)
),
policy.observation_space_struct,
)
self.filters[policy_id] = get_filter(
self.policy_map[policy_id].config.get(
"observation_filter", "NoFilter"
),
filter_shape,
)

if self.worker_index == 0:
logger.info("Built filter map: {}".format(self.filters))

Expand Down Expand Up @@ -1244,8 +1229,6 @@ def add_policy(
"""
validate_policy_id(policy_id, error=False)

merged_config = merge_dicts(self.policy_config, config or {})

if policy_id in self.policy_map:
raise KeyError(
f"Policy ID '{policy_id}' already exists in policy map! "
Expand Down Expand Up @@ -1293,26 +1276,6 @@ def add_policy(
if policy_state:
new_policy.set_state(policy_state)

connectors_enabled = merged_config.get("enable_connectors", False)

if connectors_enabled:
policy = self.policy_map[policy_id]
create_connectors_for_policy(policy, merged_config)
maybe_get_filters_for_syncing(self, policy_id)
else:
filter_shape = tree.map_structure(
lambda s: (
None
if isinstance(s, (Discrete, MultiDiscrete)) # noqa
else np.array(s.shape)
),
new_policy.observation_space_struct,
)

self.filters[policy_id] = get_filter(
(config or {}).get("observation_filter", "NoFilter"), filter_shape
)

self.set_policy_mapping_fn(policy_mapping_fn)
if policies_to_train is not None:
self.set_is_policy_to_train(policies_to_train)
Expand Down Expand Up @@ -1905,9 +1868,24 @@ def _build_policy_map(
merged_conf,
)

if connectors_enabled and name in self.policy_map:
create_connectors_for_policy(self.policy_map[name], policy_config)
new_policy = self.policy_map[name]
if connectors_enabled:
create_connectors_for_policy(new_policy, merged_conf)
maybe_get_filters_for_syncing(self, name)
else:
filter_shape = tree.map_structure(
lambda s: (
None
if isinstance(s, (Discrete, MultiDiscrete)) # noqa
else np.array(s.shape)
),
new_policy.observation_space_struct,
)

self.filters[name] = get_filter(
(merged_conf or {}).get("observation_filter", "NoFilter"),
filter_shape,
)

if name in self.policy_map:
self.callbacks.on_create_policy(
Expand Down

0 comments on commit 34cd1cb

Please sign in to comment.