[RLlib] Issue 22036: Client should handle concurrent episodes with one being `training_enabled=False`. (#22076)
sven1977 authored Feb 6, 2022
1 parent fb0d6e6 commit 8b678dd
Showing 6 changed files with 145 additions and 12 deletions.
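In short, the scenario being fixed: a PolicyClient keeps two episodes open at once, one started with training_enabled=False (used only for computing actions) and one with training_enabled=True (its logged actions and rewards are sent to the server for learning). Below is a minimal sketch of that usage, assuming a policy server such as cartpole_server.py is already listening on localhost:9900; it condenses the dummy_client_with_two_episodes.py script added by this commit.

import gym
import ray

from ray.rllib.env.policy_client import PolicyClient

ray.init()
env = gym.make("CartPole-v0")
obs = env.reset()

client = PolicyClient("http://localhost:9900", inference_mode="remote")

# Inference-only episode: none of its data should reach the server's buffers.
action_eid = client.start_episode(training_enabled=False)
# Regular episode: its logged actions/rewards are used for training.
log_eid = client.start_episode(training_enabled=True)

action = client.get_action(action_eid, obs)
client.log_action(log_eid, obs, action)
client.log_returns(log_eid, 1.0)
client.end_episode(log_eid, obs)

# The inference-only episode must keep working after the other one has ended;
# this is the concurrent case that used to fail (issue 22036).
action = client.get_action(action_eid, obs)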
22 changes: 20 additions & 2 deletions rllib/BUILD
@@ -1139,6 +1139,24 @@ sh_test(
data = glob(["examples/serving/*.py"]),
)

sh_test(
name = "env/tests/test_local_inference_cartpole_w_2_concurrent_episodes",
tags = ["team:ml", "env"],
size = "medium",
srcs = ["env/tests/test_policy_client_server_setup.sh"],
args = ["local", "cartpole-dummy-2-episodes"],
data = glob(["examples/serving/*.py"]),
)

sh_test(
name = "env/tests/test_remote_inference_cartpole_w_2_concurrent_episodes",
tags = ["team:ml", "env"],
size = "medium",
srcs = ["env/tests/test_policy_client_server_setup.sh"],
args = ["remote", "cartpole-dummy-2-episodes"],
data = glob(["examples/serving/*.py"]),
)

sh_test(
name = "env/tests/test_local_inference_unity3d",
tags = ["team:ml", "env"],
@@ -2606,9 +2624,9 @@ py_test(
name = "examples/recommender_system_with_recsim_and_slateq",
main = "examples/recommender_system_with_recsim_and_slateq.py",
tags = ["team:ml", "examples", "examples_R"],
size = "medium",
size = "large",
srcs = ["examples/recommender_system_with_recsim_and_slateq.py"],
args = ["--stop-iters=3", "--use-tune", "--num-cpus=5", "--random-test-episodes=10", "--env-num-candidates=100", "--env-slate-size=2"],
args = ["--stop-iters=2", "--use-tune", "--num-cpus=5", "--random-test-episodes=10", "--env-num-candidates=100", "--env-slate-size=2"],
)

py_test(
10 changes: 5 additions & 5 deletions rllib/env/policy_client.py
@@ -172,7 +172,7 @@ def log_action(
def log_returns(
self,
episode_id: str,
reward: int,
reward: float,
info: Union[EnvInfoDict, MultiAgentDict] = None,
multiagent_done_dict: Optional[MultiAgentDict] = None,
) -> None:
@@ -183,10 +183,10 @@ def log_returns(
logged before the next action, a reward of 0.0 is assumed.
Args:
episode_id (str): Episode id returned from start_episode().
reward (float): Reward from the environment.
info (dict): Extra info dict.
multiagent_done_dict (dict): Multi-agent done information.
episode_id: Episode id returned from start_episode().
reward: Reward from the environment.
info: Extra info dict.
multiagent_done_dict: Multi-agent done information.
"""

if self.local:
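The docstring hunk above documents the log_returns() semantics this commit relies on: the reward argument is a float, each logged reward is attributed to the most recently logged action, and an action with no reward logged before the next one is credited 0.0. A small sketch of those rules in client code (reusing the client, obs, and action names from the sketch near the top of this page; behavior as described by the docstring, not verified output):

eid = client.start_episode(training_enabled=True)

# Two reward logs between actions: both are attributed to the pending step.
client.log_action(eid, obs, action)
client.log_returns(eid, 0.5)
client.log_returns(eid, 0.7)

# Two actions with no reward logged in between: the first one is credited 0.0,
# exactly the case dummy_client_with_two_episodes.py below exercises.
client.log_action(eid, obs, action)
client.log_action(eid, obs, action)
client.log_returns(eid, 1.0)

client.end_episode(eid, obs)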
16 changes: 14 additions & 2 deletions rllib/env/tests/test_policy_client_server_setup.sh
@@ -1,5 +1,9 @@
#!/bin/bash

# Driver script for testing RLlib's client/server setup.
# Run as follows:
# $ test_policy_client_server_setup.sh [inference-mode: local|remote] [env: cartpole|cartpole-dummy-2-episodes|unity3d]

rm -f last_checkpoint.out

if [ "$1" == "local" ]; then
@@ -8,14 +12,22 @@ else
inference_mode=remote
fi

# CartPole client/server setup.
if [ "$2" == "cartpole" ]; then
server_script=cartpole_server.py
client_script=cartpole_client.py
stop_criterion="--stop-reward=150.0"
else
# Unity3D dummy setup.
elif [ "$2" == "unity3d" ]; then
server_script=unity3d_server.py
client_script=unity3d_dummy_client.py
stop_criterion="--num-episodes=10"
# CartPole dummy test using 2 simultaneous episodes on the client.
# One episode has training_enabled=False (its data should NOT arrive at server).
else
server_script=cartpole_server.py
client_script=dummy_client_with_two_episodes.py
stop_criterion="--dummy-arg=dummy" # no stop criterion: client script terminates either way
fi

pkill -f $server_script
@@ -58,6 +70,6 @@ client2_pid=$!
# x reward (CartPole) or n episodes (dummy Unity3D).
# Then stop everything.
sleep 2
python $basedir/$client_script $stop_criterion --inference-mode=$inference_mode --port=9901
python $basedir/$client_script --inference-mode=$inference_mode --port=9901 "$stop_criterion"

kill $server_pid $client1_pid $client2_pid || true
17 changes: 15 additions & 2 deletions rllib/evaluation/collectors/simple_list_collector.py
@@ -710,6 +710,10 @@ def get_inference_input_dict(self, policy_id: PolicyID) -> Dict[str, TensorType]
keys = self.forward_pass_agent_keys[policy_id]
batch_size = len(keys)

# Return empty batch, if no forward pass to do.
if batch_size == 0:
return SampleBatch()

buffers = {}
for k in keys:
collector = self.agent_collectors[k]
@@ -858,6 +862,13 @@ def postprocess_episode(
"allow this."
)

# Skip a trajectory's postprocessing (and thus using it for training),
# if its agent's info exists and contains the training_enabled=False
# setting (used by our PolicyClients).
last_info = episode.last_info_for(agent_id)
if last_info and not last_info.get("training_enabled", True):
continue

if len(pre_batches) > 1:
other_batches = pre_batches.copy()
del other_batches[agent_id]
@@ -1026,14 +1037,16 @@ def _add_to_next_inference_call(self, agent_key: Tuple[EpisodeID, AgentID]) -> N
"""
pid = self.agent_key_to_policy_id[agent_key]

# PID may be a newly added policy. Just confirm we have it in our
# policy map before proceeding with forward_pass_size=0.
# PID may be a newly added policy (added on the fly during training).
# Just confirm we have it in our policy map before proceeding with
# forward_pass_size=0.
if pid not in self.forward_pass_size:
assert pid in self.policy_map
self.forward_pass_size[pid] = 0
self.forward_pass_agent_keys[pid] = []

idx = self.forward_pass_size[pid]
assert idx >= 0
if idx == 0:
self.forward_pass_agent_keys[pid].clear()

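The new check in postprocess_episode() above is the server-side half of the fix: a trajectory is dropped from postprocessing (and hence from training) when the agent's last info dict explicitly carries training_enabled=False, which is how the client-side flag reaches the server (compare the old per-step check removed from sampler.py just below). A self-contained restatement of that rule, using a hypothetical helper name rather than the actual collector code:

# Hypothetical helper mirroring the skip condition in postprocess_episode().
def should_train_on(last_info):
    # Train unless the agent's info dict explicitly disables it.
    return not last_info or last_info.get("training_enabled", True)

assert should_train_on(None) is True                           # no info logged yet
assert should_train_on({}) is True                             # empty info dict
assert should_train_on({"training_enabled": True}) is True
assert should_train_on({"training_enabled": False}) is False   # skipped for training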
2 changes: 1 addition & 1 deletion rllib/evaluation/sampler.py
@@ -922,7 +922,7 @@ def _process_observations(
episode.length - 1,
filtered_obs,
)
elif agent_infos is None or agent_infos.get("training_enabled", True):
else:
# Add actions, rewards, next-obs to collectors.
values_dict = {
SampleBatch.T: episode.length - 1,
90 changes: 90 additions & 0 deletions rllib/examples/serving/dummy_client_with_two_episodes.py
@@ -0,0 +1,90 @@
#!/usr/bin/env python
"""
For testing purposes only.
Runs a policy client that starts two episodes, uses one for calculating actions
("action episode") and the other for logging those actions ("logging episode").
Terminates the "logging episode" before computing a few more actions
from the "action episode".
The action episode is also started with the training_enabled=False flag so no
batches should be produced by this episode for training inside the
SampleCollector's `postprocess_trajectory` method.
"""

import argparse
import gym
import ray

from ray.rllib.env.policy_client import PolicyClient

parser = argparse.ArgumentParser()
parser.add_argument(
"--inference-mode", type=str, default="local", choices=["local", "remote"]
)
parser.add_argument(
"--off-policy",
action="store_true",
help="Whether to compute random actions instead of on-policy "
"(Policy-computed) ones.",
)
parser.add_argument(
"--port", type=int, default=9900, help="The port to use (on localhost)."
)
parser.add_argument("--dummy-arg", type=str, default="")


if __name__ == "__main__":
args = parser.parse_args()

ray.init()

# Use a CartPole-v0 env so this plays nicely with our cartpole server script.
env = gym.make("CartPole-v0")

# Note that the RolloutWorker that is generated inside the client (in case
# of local inference) will contain only a RandomEnv dummy env to step through.
# The actual env we care about is the above generated CartPole one.
client = PolicyClient(
f"http://localhost:{args.port}", inference_mode=args.inference_mode
)

# Get a dummy obs
dummy_obs = env.reset()
dummy_reward = 1.3

# Start an episode to only compute actions (do NOT record this episode's
# trajectories in any returned SampleBatches sent to the server for learning).
action_eid = client.start_episode(training_enabled=False)
print(f"Starting action episode: {action_eid}.")
# Get some actions using the action episode
dummy_action = client.get_action(action_eid, dummy_obs)
print(f"Computing action 1 in action episode: {dummy_action}.")
dummy_action = client.get_action(action_eid, dummy_obs)
print(f"Computing action 2 in action episode: {dummy_action}.")

# Start a log episode to log action and log rewards for learning.
log_eid = client.start_episode(training_enabled=True)
print(f"Starting logging episode: {log_eid}.")
# Produce an action, just for testing.
garbage_action = client.get_action(log_eid, dummy_obs)
# Log 1 action and 1 reward.
client.log_action(log_eid, dummy_obs, dummy_action)
client.log_returns(log_eid, dummy_reward)
print(f".. logged action + reward: {dummy_action} + {dummy_reward}")

# Log 2 actions (w/o reward in the middle) and then one reward.
# The reward after the 1st of these actions should be considered 0.0.
client.log_action(log_eid, dummy_obs, dummy_action)
client.log_action(log_eid, dummy_obs, dummy_action)
client.log_returns(log_eid, dummy_reward)
print(f".. logged actions + reward: 2x {dummy_action} + {dummy_reward}")

# End the log episode
client.end_episode(log_eid, dummy_obs)
print(".. ended logging episode")

# Continue getting actions using the action episode
# The bug happens when executing the following line
dummy_action = client.get_action(action_eid, dummy_obs)
print(f"Computing action 3 in action episode: {dummy_action}.")
dummy_action = client.get_action(action_eid, dummy_obs)
print(f"Computing action 4 in action episode: {dummy_action}.")
