From 34cd1cbccf295db77fcfb5f81421e295f2664643 Mon Sep 17 00:00:00 2001 From: Artur Niederfahrenhorst Date: Wed, 12 Oct 2022 19:25:31 +0200 Subject: [PATCH] [RLlib] Move connectors creation code to single point in RolloutWorker (#29064) Signed-off-by: Artur Niederfahrenhorst --- rllib/evaluation/rollout_worker.py | 64 ++++++++++-------------------- 1 file changed, 21 insertions(+), 43 deletions(-) diff --git a/rllib/evaluation/rollout_worker.py b/rllib/evaluation/rollout_worker.py index 294325eac190..a3854e721af6 100644 --- a/rllib/evaluation/rollout_worker.py +++ b/rllib/evaluation/rollout_worker.py @@ -1,3 +1,4 @@ +from collections import defaultdict import copy import logging import os @@ -40,6 +41,7 @@ from ray.rllib.offline import NoopOutput, IOContext, OutputWriter, InputReader from ray.rllib.policy.policy import Policy, PolicySpec from ray.rllib.policy.policy_map import PolicyMap +from ray.rllib.utils.filter import NoFilter from ray.rllib.policy.sample_batch import ( DEFAULT_POLICY_ID, MultiAgentBatch, @@ -622,6 +624,8 @@ def wrap(env): f"is ignored." ) + self.filters: Dict[PolicyID, Filter] = defaultdict(NoFilter) + self._build_policy_map( self.policy_dict, policy_config, @@ -649,25 +653,6 @@ def wrap(env): f"MultiAgentEnv, ActorHandle, or ExternalMultiAgentEnv!" ) - # TODO(jungong) : clean up after non-connector env_runner is fully deprecated. - self.filters: Dict[PolicyID, Filter] = {} - for (policy_id, policy) in self.policy_map.items(): - if not policy_config.get("enable_connectors"): - filter_shape = tree.map_structure( - lambda s: ( - None - if isinstance(s, (Discrete, MultiDiscrete)) # noqa - else np.array(s.shape) - ), - policy.observation_space_struct, - ) - self.filters[policy_id] = get_filter( - self.policy_map[policy_id].config.get( - "observation_filter", "NoFilter" - ), - filter_shape, - ) - if self.worker_index == 0: logger.info("Built filter map: {}".format(self.filters)) @@ -1244,8 +1229,6 @@ def add_policy( """ validate_policy_id(policy_id, error=False) - merged_config = merge_dicts(self.policy_config, config or {}) - if policy_id in self.policy_map: raise KeyError( f"Policy ID '{policy_id}' already exists in policy map! " @@ -1293,26 +1276,6 @@ def add_policy( if policy_state: new_policy.set_state(policy_state) - connectors_enabled = merged_config.get("enable_connectors", False) - - if connectors_enabled: - policy = self.policy_map[policy_id] - create_connectors_for_policy(policy, merged_config) - maybe_get_filters_for_syncing(self, policy_id) - else: - filter_shape = tree.map_structure( - lambda s: ( - None - if isinstance(s, (Discrete, MultiDiscrete)) # noqa - else np.array(s.shape) - ), - new_policy.observation_space_struct, - ) - - self.filters[policy_id] = get_filter( - (config or {}).get("observation_filter", "NoFilter"), filter_shape - ) - self.set_policy_mapping_fn(policy_mapping_fn) if policies_to_train is not None: self.set_is_policy_to_train(policies_to_train) @@ -1905,9 +1868,24 @@ def _build_policy_map( merged_conf, ) - if connectors_enabled and name in self.policy_map: - create_connectors_for_policy(self.policy_map[name], policy_config) + new_policy = self.policy_map[name] + if connectors_enabled: + create_connectors_for_policy(new_policy, merged_conf) maybe_get_filters_for_syncing(self, name) + else: + filter_shape = tree.map_structure( + lambda s: ( + None + if isinstance(s, (Discrete, MultiDiscrete)) # noqa + else np.array(s.shape) + ), + new_policy.observation_space_struct, + ) + + self.filters[name] = get_filter( + (merged_conf or {}).get("observation_filter", "NoFilter"), + filter_shape, + ) if name in self.policy_map: self.callbacks.on_create_policy(