From 8b678ddd68a31a49befb4fa4e4b02df2d14a0b3b Mon Sep 17 00:00:00 2001 From: Sven Mika Date: Sun, 6 Feb 2022 12:35:03 +0100 Subject: [PATCH] [RLlib] Issue 22036: Client should handle concurrent episodes with one being `training_enabled=False`. (#22076) --- rllib/BUILD | 22 ++++- rllib/env/policy_client.py | 10 +-- .../tests/test_policy_client_server_setup.sh | 16 +++- .../collectors/simple_list_collector.py | 17 +++- rllib/evaluation/sampler.py | 2 +- .../serving/dummy_client_with_two_episodes.py | 90 +++++++++++++++++++ 6 files changed, 145 insertions(+), 12 deletions(-) create mode 100644 rllib/examples/serving/dummy_client_with_two_episodes.py diff --git a/rllib/BUILD b/rllib/BUILD index 17c628f4da27..ad35ddf0fcd3 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -1139,6 +1139,24 @@ sh_test( data = glob(["examples/serving/*.py"]), ) +sh_test( + name = "env/tests/test_local_inference_cartpole_w_2_concurrent_episodes", + tags = ["team:ml", "env"], + size = "medium", + srcs = ["env/tests/test_policy_client_server_setup.sh"], + args = ["local", "cartpole-dummy-2-episodes"], + data = glob(["examples/serving/*.py"]), +) + +sh_test( + name = "env/tests/test_remote_inference_cartpole_w_2_concurrent_episodes", + tags = ["team:ml", "env"], + size = "medium", + srcs = ["env/tests/test_policy_client_server_setup.sh"], + args = ["remote", "cartpole-dummy-2-episodes"], + data = glob(["examples/serving/*.py"]), +) + sh_test( name = "env/tests/test_local_inference_unity3d", tags = ["team:ml", "env"], @@ -2606,9 +2624,9 @@ py_test( name = "examples/recommender_system_with_recsim_and_slateq", main = "examples/recommender_system_with_recsim_and_slateq.py", tags = ["team:ml", "examples", "examples_R"], - size = "medium", + size = "large", srcs = ["examples/recommender_system_with_recsim_and_slateq.py"], - args = ["--stop-iters=3", "--use-tune", "--num-cpus=5", "--random-test-episodes=10", "--env-num-candidates=100", "--env-slate-size=2"], + args = ["--stop-iters=2", "--use-tune", "--num-cpus=5", "--random-test-episodes=10", "--env-num-candidates=100", "--env-slate-size=2"], ) py_test( diff --git a/rllib/env/policy_client.py b/rllib/env/policy_client.py index 2c243725ab11..31478eb6016a 100644 --- a/rllib/env/policy_client.py +++ b/rllib/env/policy_client.py @@ -172,7 +172,7 @@ def log_action( def log_returns( self, episode_id: str, - reward: int, + reward: float, info: Union[EnvInfoDict, MultiAgentDict] = None, multiagent_done_dict: Optional[MultiAgentDict] = None, ) -> None: @@ -183,10 +183,10 @@ def log_returns( logged before the next action, a reward of 0.0 is assumed. Args: - episode_id (str): Episode id returned from start_episode(). - reward (float): Reward from the environment. - info (dict): Extra info dict. - multiagent_done_dict (dict): Multi-agent done information. + episode_id: Episode id returned from start_episode(). + reward: Reward from the environment. + info: Extra info dict. + multiagent_done_dict: Multi-agent done information. """ if self.local: diff --git a/rllib/env/tests/test_policy_client_server_setup.sh b/rllib/env/tests/test_policy_client_server_setup.sh index 4d458ee5b8db..ca902349fe78 100755 --- a/rllib/env/tests/test_policy_client_server_setup.sh +++ b/rllib/env/tests/test_policy_client_server_setup.sh @@ -1,5 +1,9 @@ #!/bin/bash +# Driver script for testing RLlib's client/server setup. 
+# Run as follows: +# $ test_policy_client_server_setup.sh [inference-mode: local|remote] [env: cartpole|cartpole-dummy-2-episodes|unity3d] + rm -f last_checkpoint.out if [ "$1" == "local" ]; then @@ -8,14 +12,22 @@ else inference_mode=remote fi +# CartPole client/server setup. if [ "$2" == "cartpole" ]; then server_script=cartpole_server.py client_script=cartpole_client.py stop_criterion="--stop-reward=150.0" -else +# Unity3D dummy setup. +elif [ "$2" == "unity3d" ]; then server_script=unity3d_server.py client_script=unity3d_dummy_client.py stop_criterion="--num-episodes=10" +# CartPole dummy test using 2 simultaneous episodes on the client. +# One episode has training_enabled=False (its data should NOT arrive at server). +else + server_script=cartpole_server.py + client_script=dummy_client_with_two_episodes.py + stop_criterion="--dummy-arg=dummy" # no stop criterion: client script terminates either way fi pkill -f $server_script @@ -58,6 +70,6 @@ client2_pid=$! # x reward (CartPole) or n episodes (dummy Unity3D). # Then stop everything. sleep 2 -python $basedir/$client_script $stop_criterion --inference-mode=$inference_mode --port=9901 +python $basedir/$client_script --inference-mode=$inference_mode --port=9901 "$stop_criterion" kill $server_pid $client1_pid $client2_pid || true diff --git a/rllib/evaluation/collectors/simple_list_collector.py b/rllib/evaluation/collectors/simple_list_collector.py index 24e1e035c455..bd0c1451f127 100644 --- a/rllib/evaluation/collectors/simple_list_collector.py +++ b/rllib/evaluation/collectors/simple_list_collector.py @@ -710,6 +710,10 @@ def get_inference_input_dict(self, policy_id: PolicyID) -> Dict[str, TensorType] keys = self.forward_pass_agent_keys[policy_id] batch_size = len(keys) + # Return empty batch, if no forward pass to do. + if batch_size == 0: + return SampleBatch() + buffers = {} for k in keys: collector = self.agent_collectors[k] @@ -858,6 +862,13 @@ def postprocess_episode( "allow this." ) + # Skip a trajectory's postprocessing (and thus using it for training), + # if its agent's info exists and contains the training_enabled=False + # setting (used by our PolicyClients). + last_info = episode.last_info_for(agent_id) + if last_info and not last_info.get("training_enabled", True): + continue + if len(pre_batches) > 1: other_batches = pre_batches.copy() del other_batches[agent_id] @@ -1026,14 +1037,16 @@ def _add_to_next_inference_call(self, agent_key: Tuple[EpisodeID, AgentID]) -> N """ pid = self.agent_key_to_policy_id[agent_key] - # PID may be a newly added policy. Just confirm we have it in our - # policy map before proceeding with forward_pass_size=0. + # PID may be a newly added policy (added on the fly during training). + # Just confirm we have it in our policy map before proceeding with + # forward_pass_size=0. if pid not in self.forward_pass_size: assert pid in self.policy_map self.forward_pass_size[pid] = 0 self.forward_pass_agent_keys[pid] = [] idx = self.forward_pass_size[pid] + assert idx >= 0 if idx == 0: self.forward_pass_agent_keys[pid].clear() diff --git a/rllib/evaluation/sampler.py b/rllib/evaluation/sampler.py index dfead8320627..1d84ac0c0aa6 100644 --- a/rllib/evaluation/sampler.py +++ b/rllib/evaluation/sampler.py @@ -922,7 +922,7 @@ def _process_observations( episode.length - 1, filtered_obs, ) - elif agent_infos is None or agent_infos.get("training_enabled", True): + else: # Add actions, rewards, next-obs to collectors. 
values_dict = { SampleBatch.T: episode.length - 1, diff --git a/rllib/examples/serving/dummy_client_with_two_episodes.py b/rllib/examples/serving/dummy_client_with_two_episodes.py new file mode 100644 index 000000000000..7acba878004b --- /dev/null +++ b/rllib/examples/serving/dummy_client_with_two_episodes.py @@ -0,0 +1,90 @@ +#!/usr/bin/env python +""" +For testing purposes only. +Runs a policy client that starts two episodes, uses one for calculating actions +("action episode") and the other for logging those actions ("logging episode"). +Terminates the "logging episode" before computing a few more actions +from the "action episode". +The action episode is also started with the training_enabled=False flag so no +batches should be produced by this episode for training inside the +SampleCollector's `postprocess_trajectory` method. +""" + +import argparse +import gym +import ray + +from ray.rllib.env.policy_client import PolicyClient + +parser = argparse.ArgumentParser() +parser.add_argument( + "--inference-mode", type=str, default="local", choices=["local", "remote"] +) +parser.add_argument( + "--off-policy", + action="store_true", + help="Whether to compute random actions instead of on-policy " + "(Policy-computed) ones.", +) +parser.add_argument( + "--port", type=int, default=9900, help="The port to use (on localhost)." +) +parser.add_argument("--dummy-arg", type=str, default="") + + +if __name__ == "__main__": + args = parser.parse_args() + + ray.init() + + # Use a CartPole-v0 env so this plays nicely with our cartpole server script. + env = gym.make("CartPole-v0") + + # Note that the RolloutWorker that is generated inside the client (in case + # of local inference) will contain only a RandomEnv dummy env to step through. + # The actual env we care about is the above generated CartPole one. + client = PolicyClient( + f"http://localhost:{args.port}", inference_mode=args.inference_mode + ) + + # Get a dummy obs + dummy_obs = env.reset() + dummy_reward = 1.3 + + # Start an episode to only compute actions (do NOT record this episode's + # trajectories in any returned SampleBatches sent to the server for learning). + action_eid = client.start_episode(training_enabled=False) + print(f"Starting action episode: {action_eid}.") + # Get some actions using the action episode + dummy_action = client.get_action(action_eid, dummy_obs) + print(f"Computing action 1 in action episode: {dummy_action}.") + dummy_action = client.get_action(action_eid, dummy_obs) + print(f"Computing action 2 in action episode: {dummy_action}.") + + # Start a log episode to log action and log rewards for learning. + log_eid = client.start_episode(training_enabled=True) + print(f"Starting logging episode: {log_eid}.") + # Produce an action, just for testing. + garbage_action = client.get_action(log_eid, dummy_obs) + # Log 1 action and 1 reward. + client.log_action(log_eid, dummy_obs, dummy_action) + client.log_returns(log_eid, dummy_reward) + print(f".. logged action + reward: {dummy_action} + {dummy_reward}") + + # Log 2 actions (w/o reward in the middle) and then one reward. + # The reward after the 1st of these actions should be considered 0.0. + client.log_action(log_eid, dummy_obs, dummy_action) + client.log_action(log_eid, dummy_obs, dummy_action) + client.log_returns(log_eid, dummy_reward) + print(f".. logged actions + reward: 2x {dummy_action} + {dummy_reward}") + + # End the log episode + client.end_episode(log_eid, dummy_obs) + print(".. 
ended logging episode") + + # Continue getting actions using the action episode + # The bug happens when executing the following line + dummy_action = client.get_action(action_eid, dummy_obs) + print(f"Computing action 3 in action episode: {dummy_action}.") + dummy_action = client.get_action(action_eid, dummy_obs) + print(f"Computing action 4 in action episode: {dummy_action}.")