diff --git a/rllib/evaluation/env_runner_v2.py b/rllib/evaluation/env_runner_v2.py
index f4860e8e061f..1c6e99cb7ccf 100644
--- a/rllib/evaluation/env_runner_v2.py
+++ b/rllib/evaluation/env_runner_v2.py
@@ -470,6 +470,27 @@ def _get_rollout_metrics(self, episode: EpisodeV2) -> List[RolloutMetrics]:
             )
         ]
 
+    def __needs_policy_eval(self, agent_done: bool, hit_horizon: bool) -> bool:
+        """Decide whether an obs should get queued for policy eval.
+
+        Args:
+            agent_done: Whether the agent is done.
+            hit_horizon: Whether the env simply hit its horizon.
+
+        Returns:
+            Whether this obs should get queued for policy eval.
+        """
+        if hit_horizon:
+            # This case is tricky: if soft horizon is enabled, we still need to
+            # evaluate the obs to get an action for the continuing episode.
+            # Note that hit_horizon can only be True if the agent itself is not
+            # done and __all__ agents are not done.
+            return self._soft_horizon
+        if agent_done:
+            return False
+        # Otherwise, agent is alive.
+        return True
+
     def _process_observations(
         self,
         unfiltered_obs: MultiEnvDict,
@@ -665,7 +686,7 @@ def _process_observations(
                         d.agent_id, d.data.raw_dict
                     )
 
-                if not all_agents_done and not agent_dones[d.agent_id]:
+                if self.__needs_policy_eval(agent_dones[d.agent_id], hit_horizon):
                     # Add to eval set if env is not done and this particular agent
                     # is also not done.
                     item = AgentConnectorDataType(d.env_id, d.agent_id, d.data)
@@ -758,6 +779,51 @@ def _build_done_episode(
         # Clean up and delete the batch_builder.
         del self._batch_builders[env_id]
 
+    def __process_resetted_obs_for_eval(
+        self,
+        env_id: EnvID,
+        obs: Dict[EnvID, Dict[AgentID, EnvObsType]],
+        episode: EpisodeV2,
+        to_eval: Dict[PolicyID, List[AgentConnectorDataType]],
+    ):
+        """Process resetted obs through agent connectors for policy eval.
+
+        Args:
+            env_id: The env id.
+            obs: The reset obs.
+            episode: The new episode.
+            to_eval: List of agent connector data for policy eval.
+        """
+        per_policy_resetted_obs: Dict[PolicyID, List] = defaultdict(list)
+        # types: AgentID, EnvObsType
+        for agent_id, raw_obs in obs[env_id].items():
+            policy_id: PolicyID = episode.policy_for(agent_id)
+            per_policy_resetted_obs[policy_id].append((agent_id, raw_obs))
+
+        for policy_id, agents_obs in per_policy_resetted_obs.items():
+            policy = self._worker.policy_map[policy_id]
+            acd_list: List[AgentConnectorDataType] = [
+                AgentConnectorDataType(
+                    env_id,
+                    agent_id,
+                    {
+                        SampleBatch.NEXT_OBS: obs,
+                        SampleBatch.T: episode.length,
+                    },
+                )
+                for agent_id, obs in agents_obs
+            ]
+            # Call agent connectors on these initial obs.
+            processed = policy.agent_connectors(acd_list)
+
+            for d in processed:
+                episode.add_init_obs(
+                    agent_id=d.agent_id,
+                    init_obs=d.data.raw_dict[SampleBatch.NEXT_OBS],
+                    t=d.data.raw_dict[SampleBatch.T],
+                )
+                to_eval[policy_id].append(d)
+
     def _handle_done_episode(
         self,
         env_id: EnvID,
@@ -799,7 +865,8 @@ def _handle_done_episode(
         self.create_episode(env_id)
 
         # Horizon hit and we have a soft horizon (no hard env reset).
-        if not is_error and hit_horizon and self._soft_horizon:
+        soft_reset = not is_error and hit_horizon and self._soft_horizon
+        if soft_reset:
             resetted_obs: Dict[EnvID, Dict[AgentID, EnvObsType]] = {
                 env_id: env_obs_or_exception
             }
@@ -823,6 +890,7 @@
             # Reset connector state if this is a hard reset.
             for p in self._worker.policy_map.cache.values():
                 p.agent_connectors.reset(env_id)
 
+        # Reset not supported, drop this env from the ready list.
         if resetted_obs is None:
             if self._horizon != float("inf"):
@@ -836,35 +904,18 @@
         new_episode: EpisodeV2 = self._active_episodes[env_id]
         self._call_on_episode_start(new_episode, env_id)
 
-        per_policy_resetted_obs: Dict[PolicyID, List] = defaultdict(list)
-        # types: AgentID, EnvObsType
-        for agent_id, raw_obs in resetted_obs[env_id].items():
-            policy_id: PolicyID = new_episode.policy_for(agent_id)
-            per_policy_resetted_obs[policy_id].append((agent_id, raw_obs))
-
-        for policy_id, agents_obs in per_policy_resetted_obs.items():
-            policy = self._worker.policy_map[policy_id]
-            acd_list: List[AgentConnectorDataType] = [
-                AgentConnectorDataType(
-                    env_id,
-                    agent_id,
-                    {
-                        SampleBatch.NEXT_OBS: obs,
-                        SampleBatch.T: new_episode.length,
-                    },
-                )
-                for agent_id, obs in agents_obs
-            ]
-            # Call agent connectors on these initial obs.
-            processed = policy.agent_connectors(acd_list)
-
-            for d in processed:
-                new_episode.add_init_obs(
-                    agent_id=d.agent_id,
-                    init_obs=d.data.raw_dict[SampleBatch.NEXT_OBS],
-                    t=d.data.raw_dict[SampleBatch.T],
-                )
-                to_eval[policy_id].append(d)
+        if not soft_reset:
+            self.__process_resetted_obs_for_eval(
+                env_id,
+                resetted_obs,
+                new_episode,
+                to_eval,
+            )
+        else:
+            # This env was soft-reset. to_eval should already contain the
+            # processed obs for this env; see the logic in
+            # __needs_policy_eval.
+            pass
 
         # Step after adding initial obs. This will give us 0 env and agent step.
         new_episode.step()
diff --git a/rllib/evaluation/tests/test_env_runner_v2.py b/rllib/evaluation/tests/test_env_runner_v2.py
index eded08ec7b84..c16c5f26a736 100644
--- a/rllib/evaluation/tests/test_env_runner_v2.py
+++ b/rllib/evaluation/tests/test_env_runner_v2.py
@@ -351,6 +351,37 @@ def on_episode_end(
         self.assertEqual(len(outputs), 1)
         self.assertTrue(isinstance(outputs[0], RolloutMetrics))
 
+    def test_soft_horizon_works(self):
+        config = (
+            PPOConfig()
+            .framework("torch")
+            .training(
+                # Specifically ask for a batch of 200 samples.
+                train_batch_size=200,
+            )
+            .rollouts(
+                num_rollout_workers=0,
+                num_envs_per_worker=1,
+                batch_mode="complete_episodes",
+                rollout_fragment_length=10,
+                horizon=4,
+                soft_horizon=True,
+                # Enable EnvRunnerV2.
+                enable_connectors=True,
+            )
+        )
+
+        algo = PPO(config, env=DebugCounterEnv)
+
+        rollout_worker = algo.workers.local_worker()
+        sample_batch = rollout_worker.sample()
+        sample_batch = convert_ma_batch_to_sample_batch(sample_batch)
+
+        # Three logical episodes.
+        self.assertEqual(len(set(sample_batch["eps_id"])), 3)
+        # No real done bits.
+        self.assertEqual(sum(sample_batch["dones"]), 0)
+
 
 if __name__ == "__main__":
     import sys
diff --git a/rllib/evaluation/tests/test_rollout_worker.py b/rllib/evaluation/tests/test_rollout_worker.py
index 9795958f312d..b201732a6eec 100644
--- a/rllib/evaluation/tests/test_rollout_worker.py
+++ b/rllib/evaluation/tests/test_rollout_worker.py
@@ -487,8 +487,8 @@ def test_reward_clipping(self):
         # Clipping: True (clip between -1.0 and 1.0).
         config = (
             AlgorithmConfig()
-            .rollouts(num_rollout_workers=0, batch_mode="complete_episodes")
-            .environment(clip_rewards=True)
+            .rollouts(num_rollout_workers=0, batch_mode="complete_episodes")
+            .environment(clip_rewards=True)
         )
         ev = RolloutWorker(
             env_creator=lambda _: MockEnv2(episode_length=10),
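
A minimal standalone sketch of the queueing rule that __needs_policy_eval introduces, for readers following the diff. The function and assertions below are illustrative only and not part of the patch; soft_horizon stands in for the worker's self._soft_horizon, and the episode-count remark mirrors the assertions in test_soft_horizon_works.

# Illustrative sketch only (not RLlib code): the per-agent decision rule
# applied by __needs_policy_eval when queueing obs for policy evaluation.
def needs_policy_eval(agent_done: bool, hit_horizon: bool, soft_horizon: bool) -> bool:
    if hit_horizon:
        # Horizon hit while the agent (and __all__) are not done: keep
        # evaluating only if the episode continues softly (no hard reset).
        return soft_horizon
    if agent_done:
        return False
    # Agent is still alive and the horizon was not hit.
    return True

# With the test settings above (horizon=4, soft_horizon=True), horizon hits
# keep getting queued for eval, so sampling yields multiple logical episodes
# (distinct eps_id values) without ever emitting a real done bit.
assert needs_policy_eval(agent_done=False, hit_horizon=True, soft_horizon=True) is True
assert needs_policy_eval(agent_done=False, hit_horizon=True, soft_horizon=False) is False
assert needs_policy_eval(agent_done=True, hit_horizon=False, soft_horizon=True) is False
assert needs_policy_eval(agent_done=False, hit_horizon=False, soft_horizon=True) is True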