Commit c361850

[RLlib] Policy Server/Client metrics reporting fix (#24783)
ArturNiederfahrenhorst authored and maxpumperla committed May 18, 2022
1 parent 14737a2 commit c361850
Showing 2 changed files with 82 additions and 36 deletions.
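
In short, this commit moves the policy client/server command constants into a shared Commands enum and, on the server side, installs a dummy sampler when the rollout worker has none, so that metrics reported by policy clients still reach the trainer. The sketch below is not part of the commit; it only illustrates, in plain Python, the non-blocking queue-drain pattern that both the patched get_metrics and the new MetricsDummySampler rely on (the placeholder dicts stand in for RolloutMetrics objects).

import queue

metrics_queue = queue.Queue()

# The server's request handler would push client-reported metrics here
# (stand-ins for the RolloutMetrics objects used in the real code).
metrics_queue.put({"episode_reward": 1.0})
metrics_queue.put({"episode_reward": 2.5})

def get_metrics():
    """Drain everything currently queued, without blocking."""
    completed = []
    while True:
        try:
            completed.append(metrics_queue.get_nowait())
        except queue.Empty:
            break
    return completed

print(get_metrics())  # -> both metric dicts
print(get_metrics())  # -> [] (queue already drained)
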
26 changes: 15 additions & 11 deletions rllib/env/policy_client.py
@@ -8,6 +8,7 @@
import threading
import time
from typing import Union, Optional
from enum import Enum

import ray.cloudpickle as pickle
from ray.rllib.env.external_env import ExternalEnv
@@ -36,9 +37,7 @@


@PublicAPI
class PolicyClient:
"""REST client to interact with a RLlib policy server."""

class Commands(Enum):
# Generic commands (for both modes).
ACTION_SPACE = "ACTION_SPACE"
OBSERVATION_SPACE = "OBSERVATION_SPACE"
@@ -55,6 +54,11 @@ class PolicyClient:
LOG_RETURNS = "LOG_RETURNS"
END_EPISODE = "END_EPISODE"


@PublicAPI
class PolicyClient:
"""REST client to interact with an RLlib policy server."""

@PublicAPI
def __init__(
self, address: str, inference_mode: str = "local", update_interval: float = 10.0
@@ -102,7 +106,7 @@ def start_episode(
return self._send(
{
"episode_id": episode_id,
"command": PolicyClient.START_EPISODE,
"command": Commands.START_EPISODE,
"training_enabled": training_enabled,
}
)["episode_id"]
@@ -134,7 +138,7 @@ def get_action(
else:
return self._send(
{
"command": PolicyClient.GET_ACTION,
"command": Commands.GET_ACTION,
"observation": observation,
"episode_id": episode_id,
}
@@ -161,7 +165,7 @@ def log_action(

self._send(
{
"command": PolicyClient.LOG_ACTION,
"command": Commands.LOG_ACTION,
"observation": observation,
"action": action,
"episode_id": episode_id,
@@ -200,7 +204,7 @@ def log_returns(

self._send(
{
"command": PolicyClient.LOG_RETURNS,
"command": Commands.LOG_RETURNS,
"reward": reward,
"info": info,
"episode_id": episode_id,
@@ -225,7 +229,7 @@ def end_episode(

self._send(
{
"command": PolicyClient.END_EPISODE,
"command": Commands.END_EPISODE,
"observation": observation,
"episode_id": episode_id,
}
@@ -252,7 +256,7 @@ def _setup_local_rollout_worker(self, update_interval):
logger.info("Querying server for rollout worker settings.")
kwargs = self._send(
{
"command": PolicyClient.GET_WORKER_ARGS,
"command": Commands.GET_WORKER_ARGS,
}
)["worker_args"]
(self.rollout_worker, self.inference_thread) = _create_embedded_rollout_worker(
@@ -269,7 +273,7 @@ def _update_local_policy(self, force=False):
logger.info("Querying server for new policy weights.")
resp = self._send(
{
"command": PolicyClient.GET_WEIGHTS,
"command": Commands.GET_WEIGHTS,
}
)
weights = resp["weights"]
@@ -311,7 +315,7 @@ def run(self):
)
self.send_fn(
{
"command": PolicyClient.REPORT_SAMPLES,
"command": Commands.REPORT_SAMPLES,
"samples": samples,
"metrics": metrics,
}
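The client methods touched above (start_episode, get_action, log_returns, end_episode) are typically driven from an external environment loop. A hypothetical usage sketch follows; the server address, port, and the CartPole env are illustrative assumptions, not part of this commit.

import gym
from ray.rllib.env.policy_client import PolicyClient

env = gym.make("CartPole-v0")
# "remote" inference mode: every get_action() is answered by the policy server.
client = PolicyClient("http://localhost:9900", inference_mode="remote")

obs = env.reset()
episode_id = client.start_episode(training_enabled=True)
done = False
while not done:
    action = client.get_action(episode_id, obs)
    obs, reward, done, info = env.step(action)
    # These returns are what eventually surface as RolloutMetrics on the server.
    client.log_returns(episode_id, reward, info=info)
client.end_episode(episode_id, obs)
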
92 changes: 67 additions & 25 deletions rllib/env/policy_server_input.py
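One note before this file's diff: the request handler below compares each unpickled "command" field against Commands members, which works because module-level Enum members round-trip through pickling as the same singleton objects. A stand-alone sanity check of that assumption (plain pickle here, whereas the real code uses ray.cloudpickle; the two-member enum is a trimmed stand-in):

import pickle
from enum import Enum

class Commands(Enum):
    # Trimmed stand-in for rllib.env.policy_client.Commands.
    GET_ACTION = "GET_ACTION"
    START_EPISODE = "START_EPISODE"

payload = pickle.dumps({"command": Commands.GET_ACTION, "episode_id": "ep1"})
command = pickle.loads(payload)["command"]
assert command is Commands.GET_ACTION   # same member object after unpickling
assert command == Commands.GET_ACTION   # so the handler's equality checks hold
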
@@ -6,11 +6,18 @@
import time
import traceback

from typing import List
import ray.cloudpickle as pickle
from ray.rllib.env.policy_client import PolicyClient, _create_embedded_rollout_worker
from ray.rllib.env.policy_client import (
_create_embedded_rollout_worker,
Commands,
)
from ray.rllib.offline.input_reader import InputReader
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.annotations import override, PublicAPI
from ray.rllib.evaluation.metrics import RolloutMetrics
from ray.rllib.evaluation.sampler import SamplerInput
from ray.rllib.utils.typing import SampleBatchType

logger = logging.getLogger(__name__)

@@ -77,22 +84,56 @@ def __init__(self, ioctx, address, port, idle_timeout=3.0):
self.metrics_queue = queue.Queue()
self.idle_timeout = idle_timeout

def get_metrics():
completed = []
while True:
try:
completed.append(self.metrics_queue.get_nowait())
except queue.Empty:
break
return completed

# Forwards client-reported rewards directly into the local rollout
# worker. This is a bit of a hack since it is patching the get_metrics
# function of the sampler.
# Forwards client-reported metrics directly into the local rollout
# worker.
if self.rollout_worker.sampler is not None:
self.rollout_worker.sampler.get_metrics = get_metrics
# This is a bit of a hack since it is patching the get_metrics
# function of the sampler.

def get_metrics():
completed = []
while True:
try:
completed.append(self.metrics_queue.get_nowait())
except queue.Empty:
break

return completed

# Create a request handler that receives commands from the clients
self.rollout_worker.sampler.get_metrics = get_metrics
else:
# If there is no sampler, act as if there were one so that
# client-reported metrics can still be collected.
class MetricsDummySampler(SamplerInput):
"""This sampler only maintains a queue to get metrics from."""

def __init__(self, metrics_queue):
"""Initializes an AsyncSampler instance.
Args:
metrics_queue: A queue of metrics
"""
self.metrics_queue = metrics_queue

def get_data(self) -> SampleBatchType:
raise NotImplementedError

def get_extra_batches(self) -> List[SampleBatchType]:
raise NotImplementedError

def get_metrics(self) -> List[RolloutMetrics]:
"""Returns metrics computed on a policy client rollout worker."""
completed = []
while True:
try:
completed.append(self.metrics_queue.get_nowait())
except queue.Empty:
break
return completed

self.rollout_worker.sampler = MetricsDummySampler(self.metrics_queue)

# Create a request handler that receives commands from the clients
# and sends data and metrics into the queues.
handler = _make_handler(
self.rollout_worker, self.samples_queue, self.metrics_queue
@@ -153,10 +194,11 @@ def _make_handler(rollout_worker, samples_queue, metrics_queue):

def setup_child_rollout_worker():
nonlocal lock
nonlocal child_rollout_worker
nonlocal inference_thread

with lock:
nonlocal child_rollout_worker
nonlocal inference_thread

if child_rollout_worker is None:
(
child_rollout_worker,
@@ -201,14 +243,14 @@ def execute_command(self, args):
response = {}

# Local inference commands:
if command == PolicyClient.GET_WORKER_ARGS:
if command == Commands.GET_WORKER_ARGS:
logger.info("Sending worker creation args to client.")
response["worker_args"] = rollout_worker.creation_args()
elif command == PolicyClient.GET_WEIGHTS:
elif command == Commands.GET_WEIGHTS:
logger.info("Sending worker weights to client.")
response["weights"] = rollout_worker.get_weights()
response["global_vars"] = rollout_worker.get_global_vars()
elif command == PolicyClient.REPORT_SAMPLES:
elif command == Commands.REPORT_SAMPLES:
logger.info(
"Got sample batch of size {} from client.".format(
args["samples"].count
@@ -217,23 +259,23 @@ def execute_command(self, args):
report_data(args)

# Remote inference commands:
elif command == PolicyClient.START_EPISODE:
elif command == Commands.START_EPISODE:
setup_child_rollout_worker()
assert inference_thread.is_alive()
response["episode_id"] = child_rollout_worker.env.start_episode(
args["episode_id"], args["training_enabled"]
)
elif command == PolicyClient.GET_ACTION:
elif command == Commands.GET_ACTION:
assert inference_thread.is_alive()
response["action"] = child_rollout_worker.env.get_action(
args["episode_id"], args["observation"]
)
elif command == PolicyClient.LOG_ACTION:
elif command == Commands.LOG_ACTION:
assert inference_thread.is_alive()
child_rollout_worker.env.log_action(
args["episode_id"], args["observation"], args["action"]
)
elif command == PolicyClient.LOG_RETURNS:
elif command == Commands.LOG_RETURNS:
assert inference_thread.is_alive()
if args["done"]:
child_rollout_worker.env.log_returns(
@@ -243,7 +285,7 @@
child_rollout_worker.env.log_returns(
args["episode_id"], args["reward"], args["info"]
)
elif command == PolicyClient.END_EPISODE:
elif command == Commands.END_EPISODE:
assert inference_thread.is_alive()
child_rollout_worker.env.end_episode(
args["episode_id"], args["observation"]
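For context, a hypothetical server-side setup in the style of RLlib's serving examples (trainer choice, spaces, address, and port are assumptions, not part of this commit). With "env": None the server-side worker has no regular sampler, which is exactly the case the MetricsDummySampler above was added for, so episodes reported by clients show up in trainer.train() results.

import gym
import ray
from ray.rllib.agents.ppo import PPOTrainer
from ray.rllib.env.policy_server_input import PolicyServerInput

ray.init()
config = {
    "env": None,  # observations and rewards come from connected policy clients
    "observation_space": gym.spaces.Box(float("-inf"), float("inf"), (4,)),
    "action_space": gym.spaces.Discrete(2),
    # Plug the policy server in as the trainer's experience input.
    "input": lambda ioctx: PolicyServerInput(ioctx, "localhost", 9900),
    "num_workers": 0,
    "input_evaluation": [],  # no offline evaluation; data is on-policy from clients
}
trainer = PPOTrainer(config=config)
while True:
    result = trainer.train()
    # With the fix above, client-reported episode stats are reflected here.
    print(result.get("episode_reward_mean"))
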
