[RLlib] [Connector] Fix EnvRunnerV2's handling of soft_horizon episodes. #30434

Merged 3 commits on Nov 18, 2022
113 changes: 82 additions & 31 deletions rllib/evaluation/env_runner_v2.py
@@ -470,6 +470,27 @@ def _get_rollout_metrics(self, episode: EpisodeV2) -> List[RolloutMetrics]:
)
]

def __needs_policy_eval(self, agent_done: bool, hit_horizon: bool) -> bool:
"""Decide whether an obs should get queued for policy eval.

Args:
agent_done: Whether the agent is done.
hit_horizon: Whether the env simply hit its horizon.

Returns:
Whether this obs should get queued for policy eval.
"""
if hit_horizon:
# The env hit its horizon, which only happens when neither this agent
# nor __all__ is done. With a soft horizon there is no hard env reset,
# so the obs still needs to be evaluated for the next action; with a
# hard horizon it does not.
return self._soft_horizon
if agent_done:
return False
# Otherwise, agent is alive.
return True
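
For reference, a minimal standalone sketch of the same decision logic with soft_horizon passed in explicitly; the helper name is hypothetical and only mirrors __needs_policy_eval above for illustration:

def needs_policy_eval(agent_done: bool, hit_horizon: bool, soft_horizon: bool) -> bool:
    # Horizon hit without the agent (or __all__) being done: keep evaluating
    # only if the horizon is soft, i.e. there is no hard env reset.
    if hit_horizon:
        return soft_horizon
    # Otherwise, a live agent gets an action; a done agent does not.
    return not agent_done

# Live agent, no horizon hit -> evaluate.
assert needs_policy_eval(agent_done=False, hit_horizon=False, soft_horizon=False)
# Horizon hit with soft_horizon enabled -> still evaluate (the case this PR fixes).
assert needs_policy_eval(agent_done=False, hit_horizon=True, soft_horizon=True)
# Horizon hit with a hard horizon -> no evaluation; the env gets reset instead.
assert not needs_policy_eval(agent_done=False, hit_horizon=True, soft_horizon=False)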

def _process_observations(
self,
unfiltered_obs: MultiEnvDict,
@@ -665,7 +686,7 @@ def _process_observations(
d.agent_id, d.data.raw_dict
)

if not all_agents_done and not agent_dones[d.agent_id]:
if self.__needs_policy_eval(agent_dones[d.agent_id], hit_horizon):
# Add to eval set if this agent still needs an action, i.e. it is not
# done, or the env merely hit a soft horizon (see __needs_policy_eval).
item = AgentConnectorDataType(d.env_id, d.agent_id, d.data)
@@ -758,6 +779,51 @@ def _build_done_episode(
# Clean up and delete the batch_builder.
del self._batch_builders[env_id]

def __process_resetted_obs_for_eval(
self,
env_id: EnvID,
obs: Dict[EnvID, Dict[AgentID, EnvObsType]],
episode: EpisodeV2,
to_eval: Dict[PolicyID, List[AgentConnectorDataType]],
):
"""Process resetted obs through agent connectors for policy eval.

Args:
env_id: The env id.
obs: The resetted obs.
episode: New episode.
to_eval: List of agent connector data for policy eval.
"""
per_policy_resetted_obs: Dict[PolicyID, List] = defaultdict(list)
# types: AgentID, EnvObsType
for agent_id, raw_obs in obs[env_id].items():
policy_id: PolicyID = episode.policy_for(agent_id)
per_policy_resetted_obs[policy_id].append((agent_id, raw_obs))

for policy_id, agents_obs in per_policy_resetted_obs.items():
policy = self._worker.policy_map[policy_id]
acd_list: List[AgentConnectorDataType] = [
AgentConnectorDataType(
env_id,
agent_id,
{
SampleBatch.NEXT_OBS: obs,
SampleBatch.T: episode.length,
},
)
for agent_id, obs in agents_obs
]
# Call agent connectors on these initial obs.
processed = policy.agent_connectors(acd_list)

for d in processed:
episode.add_init_obs(
agent_id=d.agent_id,
init_obs=d.data.raw_dict[SampleBatch.NEXT_OBS],
t=d.data.raw_dict[SampleBatch.T],
)
to_eval[policy_id].append(d)
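
The per-policy grouping in this helper is a plain defaultdict pattern; a minimal standalone sketch of the same idea, with made-up agent and policy names used purely for illustration:

from collections import defaultdict

# Group each agent's raw obs under the policy that controls it, so the agent
# connectors can be called once per policy on a batch of (agent_id, obs) pairs.
agent_to_policy = {"agent_0": "policy_a", "agent_1": "policy_b", "agent_2": "policy_a"}
raw_obs = {"agent_0": 0.1, "agent_1": 0.2, "agent_2": 0.3}

per_policy_obs = defaultdict(list)
for agent_id, ob in raw_obs.items():
    per_policy_obs[agent_to_policy[agent_id]].append((agent_id, ob))

assert per_policy_obs == {
    "policy_a": [("agent_0", 0.1), ("agent_2", 0.3)],
    "policy_b": [("agent_1", 0.2)],
}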

def _handle_done_episode(
self,
env_id: EnvID,
@@ -799,7 +865,8 @@ def _handle_done_episode(
self.create_episode(env_id)

# Horizon hit and we have a soft horizon (no hard env reset).
if not is_error and hit_horizon and self._soft_horizon:
soft_reset = not is_error and hit_horizon and self._soft_horizon
if soft_reset:
resetted_obs: Dict[EnvID, Dict[AgentID, EnvObsType]] = {
env_id: env_obs_or_exception
}
@@ -823,6 +890,7 @@
# Reset connector state if this is a hard reset.
for p in self._worker.policy_map.cache.values():
p.agent_connectors.reset(env_id)

# Reset not supported, drop this env from the ready list.
if resetted_obs is None:
if self._horizon != float("inf"):
Expand All @@ -836,35 +904,18 @@ def _handle_done_episode(
new_episode: EpisodeV2 = self._active_episodes[env_id]
self._call_on_episode_start(new_episode, env_id)

per_policy_resetted_obs: Dict[PolicyID, List] = defaultdict(list)
# types: AgentID, EnvObsType
for agent_id, raw_obs in resetted_obs[env_id].items():
policy_id: PolicyID = new_episode.policy_for(agent_id)
per_policy_resetted_obs[policy_id].append((agent_id, raw_obs))

for policy_id, agents_obs in per_policy_resetted_obs.items():
policy = self._worker.policy_map[policy_id]
acd_list: List[AgentConnectorDataType] = [
AgentConnectorDataType(
env_id,
agent_id,
{
SampleBatch.NEXT_OBS: obs,
SampleBatch.T: new_episode.length,
},
)
for agent_id, obs in agents_obs
]
# Call agent connectors on these initial obs.
processed = policy.agent_connectors(acd_list)

for d in processed:
new_episode.add_init_obs(
agent_id=d.agent_id,
init_obs=d.data.raw_dict[SampleBatch.NEXT_OBS],
t=d.data.raw_dict[SampleBatch.T],
)
to_eval[policy_id].append(d)
if not soft_reset:
self.__process_resetted_obs_for_eval(
env_id,
resetted_obs,
new_episode,
to_eval,
)
else:
# This env was soft-reset, so to_eval already contains the processed
# obs for it (see __needs_policy_eval above).
pass

# Step after adding initial obs. This will give us 0 env and agent step.
new_episode.step()
31 changes: 31 additions & 0 deletions rllib/evaluation/tests/test_env_runner_v2.py
@@ -351,6 +351,37 @@ def on_episode_end(
self.assertEqual(len(outputs), 1)
self.assertTrue(isinstance(outputs[0], RolloutMetrics))

def test_soft_horizon_works(self):
config = (
PPOConfig()
.framework("torch")
.training(
# Specifically ask for a batch of 200 samples.
train_batch_size=200,
)
.rollouts(
num_rollout_workers=0,
num_envs_per_worker=1,
batch_mode="complete_episodes",
rollout_fragment_length=10,
horizon=4,
soft_horizon=True,
# Enable EnvRunnerV2.
enable_connectors=True,
)
)

algo = PPO(config, env=DebugCounterEnv)

rollout_worker = algo.workers.local_worker()
sample_batch = rollout_worker.sample()
sample_batch = convert_ma_batch_to_sample_batch(sample_batch)

# Three logical episodes.
self.assertEqual(len(set(sample_batch["eps_id"])), 3)
# No real done bits (a soft horizon does not set done=True).
self.assertEqual(sum(sample_batch["dones"]), 0)
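
For context, a back-of-the-envelope check of the expected episode count, assuming DebugCounterEnv never terminates on its own so every episode ends at the soft horizon:

import math

horizon = 4
rollout_fragment_length = 10

# With batch_mode="complete_episodes", the worker keeps sampling whole
# (soft-horizon) episodes until at least rollout_fragment_length steps have
# been collected: 4 + 4 + 4 = 12 steps across three logical episodes.
num_episodes = math.ceil(rollout_fragment_length / horizon)
assert num_episodes == 3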


if __name__ == "__main__":
import sys
4 changes: 2 additions & 2 deletions rllib/evaluation/tests/test_rollout_worker.py
@@ -487,8 +487,8 @@ def test_reward_clipping(self):
# Clipping: True (clip between -1.0 and 1.0).
config = (
AlgorithmConfig()
.rollouts(num_rollout_workers=0, batch_mode="complete_episodes")
.environment(clip_rewards=True)
.rollouts(num_rollout_workers=0, batch_mode="complete_episodes")
.environment(clip_rewards=True)
)
ev = RolloutWorker(
env_creator=lambda _: MockEnv2(episode_length=10),