From ea2bea7e309cd60457aa0e027321be5f10fa0fe5 Mon Sep 17 00:00:00 2001 From: Sven Mika Date: Mon, 1 Nov 2021 10:59:53 +0100 Subject: [PATCH 01/15] [RLlib; Docs overhaul] Docstring cleanup: Offline. (#19808) --- rllib/evaluation/rollout_worker.py | 7 +- rllib/offline/d4rl_reader.py | 6 +- rllib/offline/input_reader.py | 13 ++-- rllib/offline/io_context.py | 46 +++++++---- rllib/offline/is_estimator.py | 2 +- rllib/offline/json_reader.py | 53 ++++++++----- rllib/offline/json_writer.py | 11 ++- rllib/offline/off_policy_estimator.py | 89 +++++++++++++++++---- rllib/offline/output_writer.py | 8 +- rllib/offline/shuffled_input.py | 6 +- rllib/offline/wis_estimator.py | 108 +++++++++++++------------- 11 files changed, 216 insertions(+), 133 deletions(-) diff --git a/rllib/evaluation/rollout_worker.py b/rllib/evaluation/rollout_worker.py index 2ee532096d82..8fd1ebcf6208 100644 --- a/rllib/evaluation/rollout_worker.py +++ b/rllib/evaluation/rollout_worker.py @@ -664,11 +664,12 @@ def make_sub_env(vector_index): "will discard all sampler outputs and keep only metrics.") sample_async = True elif method == "is": - ise = ImportanceSamplingEstimator.create(self.io_context) + ise = ImportanceSamplingEstimator.\ + create_from_io_context(self.io_context) self.reward_estimators.append(ise) elif method == "wis": - wise = WeightedImportanceSamplingEstimator.create( - self.io_context) + wise = WeightedImportanceSamplingEstimator.\ + create_from_io_context(self.io_context) self.reward_estimators.append(wise) else: raise ValueError( diff --git a/rllib/offline/d4rl_reader.py b/rllib/offline/d4rl_reader.py index d191d65c61f3..dae9cccb019a 100644 --- a/rllib/offline/d4rl_reader.py +++ b/rllib/offline/d4rl_reader.py @@ -17,11 +17,11 @@ class D4RLReader(InputReader): @PublicAPI def __init__(self, inputs: str, ioctx: IOContext = None): - """Initialize a D4RLReader. + """Initializes a D4RLReader instance. Args: - inputs (str): String corresponding to D4RL environment name - ioctx (IOContext): Current IO context object. + inputs: String corresponding to the D4RL environment name. + ioctx: Current IO context object. """ import d4rl self.env = gym.make(inputs) diff --git a/rllib/offline/input_reader.py b/rllib/offline/input_reader.py index 3b05e4772402..12ac65474027 100644 --- a/rllib/offline/input_reader.py +++ b/rllib/offline/input_reader.py @@ -16,15 +16,16 @@ @PublicAPI class InputReader(metaclass=ABCMeta): - """Input object for loading experiences in policy evaluation.""" + """API for collecting and returning experiences during policy evaluation. + """ @abstractmethod @PublicAPI - def next(self): - """Returns the next batch of experiences read. + def next(self) -> SampleBatchType: + """Returns the next batch of read experiences. Returns: - Union[SampleBatch, MultiAgentBatch]: The experience read. + The experience read (SampleBatch or MultiAgentBatch). """ raise NotImplementedError @@ -40,7 +41,7 @@ def tf_input_ops(self, queue_size: int = 1) -> Dict[str, TensorType]: reader repeatedly to feed the TensorFlow queue. Args: - queue_size (int): Max elements to allow in the TF queue. + queue_size: Max elements to allow in the TF queue. Example: >>> class MyModel(rllib.model.Model): @@ -56,7 +57,7 @@ def tf_input_ops(self, queue_size: int = 1) -> Dict[str, TensorType]: You can find a runnable version of this in examples/custom_loss.py. Returns: - dict of Tensors, one for each column of the read SampleBatch. + Dict of Tensors, one for each column of the read SampleBatch. """ if hasattr(self, "_queue_runner"): diff --git a/rllib/offline/io_context.py b/rllib/offline/io_context.py index f13103b7f295..c74db614c4f3 100644 --- a/rllib/offline/io_context.py +++ b/rllib/offline/io_context.py @@ -1,37 +1,53 @@ import os +from typing import Any, Optional, TYPE_CHECKING from ray.rllib.utils.annotations import PublicAPI -from typing import Any +from ray.rllib.utils.typing import TrainerConfigDict + +if TYPE_CHECKING: + from ray.rllib.evaluation.sampler import SamplerInput @PublicAPI class IOContext: - """Attributes to pass to input / output class constructors. - - RLlib auto-sets these attributes when constructing input / output classes. + """Class containing attributes to pass to input/output class constructors. - Attributes: - log_dir (str): Default logging directory. - config (dict): Configuration of the agent. - worker_index (int): When there are multiple workers created, this - uniquely identifies the current worker. - worker (RolloutWorker): RolloutWorker object reference. - input_config (dict): The input configuration for custom input. + RLlib auto-sets these attributes when constructing input/output classes, + such as InputReaders and OutputWriters. """ @PublicAPI def __init__(self, - log_dir: str = None, - config: dict = None, + log_dir: Optional[str] = None, + config: Optional[TrainerConfigDict] = None, worker_index: int = 0, - worker: Any = None): + worker: Optional[Any] = None): + """Initializes a IOContext object. + + Args: + log_dir: The logging directory to read from/write to. + config: The Trainer's main config dict. + worker_index (int): When there are multiple workers created, this + uniquely identifies the current worker. 0 for the local + worker, >0 for any of the remote workers. + worker (RolloutWorker): The RolloutWorker object reference. + """ self.log_dir = log_dir or os.getcwd() self.config = config or {} self.worker_index = worker_index self.worker = worker @PublicAPI - def default_sampler_input(self) -> Any: + def default_sampler_input(self) -> Optional["SamplerInput"]: + """Returns the RolloutWorker's SamplerInput object, if any. + + Returns None if the RolloutWorker has no SamplerInput. Note that local + workers in case there are also one or more remote workers by default + do not create a SamplerInput object. + + Returns: + The RolloutWorkers' SamplerInput object or None if none exists. + """ return self.worker.sampler @PublicAPI diff --git a/rllib/offline/is_estimator.py b/rllib/offline/is_estimator.py index 119eb2e1c97f..242c5f291fa8 100644 --- a/rllib/offline/is_estimator.py +++ b/rllib/offline/is_estimator.py @@ -14,7 +14,7 @@ def estimate(self, batch: SampleBatchType) -> OffPolicyEstimate: self.check_can_estimate_for(batch) rewards, old_prob = batch["rewards"], batch["action_prob"] - new_prob = self.action_prob(batch) + new_prob = self.action_log_likelihood(batch) # calculate importance ratios p = [] diff --git a/rllib/offline/json_reader.py b/rllib/offline/json_reader.py index 2177ea27ef98..01da73d87489 100644 --- a/rllib/offline/json_reader.py +++ b/rllib/offline/json_reader.py @@ -5,7 +5,7 @@ from pathlib import Path import random import re -from typing import List, Optional +from typing import List, Optional, Union from urllib.parse import urlparse import zipfile @@ -32,17 +32,20 @@ class JsonReader(InputReader): """Reader object that loads experiences from JSON file chunks. - The input files will be read from in an random order.""" + The input files will be read from in random order. + """ @PublicAPI - def __init__(self, inputs: List[str], ioctx: IOContext = None): - """Initialize a JsonReader. + def __init__(self, + inputs: Union[str, List[str]], + ioctx: Optional[IOContext] = None): + """Initializes a JsonReader instance. Args: - inputs (str|list): Either a glob expression for files, e.g., - "/tmp/**/*.json", or a list of single file paths or URIs, e.g., + inputs: Either a glob expression for files, e.g. `/tmp/**/*.json`, + or a list of single file paths or URIs, e.g., ["s3://bucket/file.json", "s3://bucket/file2.json"]. - ioctx (IOContext): Current IO context object. + ioctx: Current IO context object or None. """ self.ioctx = ioctx or IOContext() @@ -72,8 +75,8 @@ def __init__(self, inputs: List[str], ioctx: IOContext = None): self.files = [] for i in inputs: self.files.extend(glob.glob(i)) - elif type(inputs) is list: - self.files = inputs + elif isinstance(inputs, (list, tuple)): + self.files = list(inputs) else: raise ValueError( "type of inputs must be list or str, not {}".format(inputs)) @@ -98,6 +101,26 @@ def next(self) -> SampleBatchType: return self._postprocess_if_needed(batch) + def read_all_files(self) -> SampleBatchType: + """Reads through all files and yields one SampleBatchType per line. + + When reaching the end of the last file, will start from the beginning + again. + + Yields: + One SampleBatch or MultiAgentBatch per line in all input files. + """ + for path in self.files: + file = self._try_open_file(path) + while True: + line = file.readline() + if not line: + break + batch = self._try_parse(line) + if batch is None: + break + yield batch + def _postprocess_if_needed(self, batch: SampleBatchType) -> SampleBatchType: if not self.ioctx.config.get("postprocess_inputs"): @@ -182,18 +205,6 @@ def _try_parse(self, line: str) -> Optional[SampleBatchType]: self.ioctx.worker.policy_map[pid].action_space_struct) return batch - def read_all_files(self): - for path in self.files: - file = self._try_open_file(path) - while True: - line = file.readline() - if not line: - break - batch = self._try_parse(line) - if batch is None: - break - yield batch - def _next_line(self) -> str: if not self.cur_file: self.cur_file = self._next_file() diff --git a/rllib/offline/json_writer.py b/rllib/offline/json_writer.py index d3c849684e49..77777872dc55 100644 --- a/rllib/offline/json_writer.py +++ b/rllib/offline/json_writer.py @@ -34,15 +34,14 @@ def __init__(self, ioctx: IOContext = None, max_file_size: int = 64 * 1024 * 1024, compress_columns: List[str] = frozenset(["obs", "new_obs"])): - """Initialize a JsonWriter. + """Initializes a JsonWriter instance. Args: - path (str): a path/URI of the output directory to save files in. - ioctx (IOContext): current IO context object. - max_file_size (int): max size of single files before rolling over. - compress_columns (list): list of sample batch columns to compress. + path: a path/URI of the output directory to save files in. + ioctx: current IO context object. + max_file_size: max size of single files before rolling over. + compress_columns: list of sample batch columns to compress. """ - self.ioctx = ioctx or IOContext() self.max_file_size = max_file_size self.compress_columns = compress_columns diff --git a/rllib/offline/off_policy_estimator.py b/rllib/offline/off_policy_estimator.py index 871ad67e0436..e04f8238682f 100644 --- a/rllib/offline/off_policy_estimator.py +++ b/rllib/offline/off_policy_estimator.py @@ -7,6 +7,7 @@ from ray.rllib.policy import Policy from ray.rllib.utils.annotations import DeveloperAPI from ray.rllib.offline.io_context import IOContext +from ray.rllib.utils.annotations import Deprecated from ray.rllib.utils.numpy import convert_to_numpy from ray.rllib.utils.typing import TensorType, SampleBatchType from typing import List @@ -23,19 +24,30 @@ class OffPolicyEstimator: @DeveloperAPI def __init__(self, policy: Policy, gamma: float): - """Creates an off-policy estimator. + """Initializes an OffPolicyEstimator instance. Args: - policy (Policy): Policy to evaluate. - gamma (float): Discount of the MDP. + policy: Policy to evaluate. + gamma: Discount factor of the environment. """ self.policy = policy self.gamma = gamma self.new_estimates = [] @classmethod - def create(cls, ioctx: IOContext) -> "OffPolicyEstimator": - """Create an off-policy estimator from a IOContext.""" + def create_from_io_context(cls, ioctx: IOContext) -> "OffPolicyEstimator": + """Creates an off-policy estimator from an IOContext object. + + Extracts Policy and gamma (discount factor) information from the + IOContext. + + Args: + ioctx: The IOContext object to create the OffPolicyEstimator + from. + + Returns: + The OffPolicyEstimator object created from the IOContext object. + """ gamma = ioctx.worker.policy_config["gamma"] # Grab a reference to the current model keys = list(ioctx.worker.policy_map.keys()) @@ -47,18 +59,36 @@ def create(cls, ioctx: IOContext) -> "OffPolicyEstimator": return cls(policy, gamma) @DeveloperAPI - def estimate(self, batch: SampleBatchType): - """Returns an estimate for the given batch of experiences. + def estimate(self, batch: SampleBatchType) -> OffPolicyEstimate: + """Returns an off policy estimate for the given batch of experiences. + + The batch will at most only contain data from one episode, + but it may also only be a fragment of an episode. - The batch will only contain data from one episode, but it may only be - a fragment of an episode. + Args: + batch: The batch to calculate the off policy estimate (OPE) on. + + Returns: + The off-policy estimates (OPE) calculated on the given batch. """ raise NotImplementedError @DeveloperAPI - def action_prob(self, batch: SampleBatchType) -> np.ndarray: - """Returns the probs for the batch actions for the current policy.""" + def action_log_likelihood(self, batch: SampleBatchType) -> TensorType: + """Returns log likelihoods for actions in given batch for policy. + + Computes likelihoods by passing the observations through the current + policy's `compute_log_likelihoods()` method. + + Args: + batch: The SampleBatch or MultiAgentBatch to calculate action + log likelihoods from. This batch/batches must contain OBS + and ACTIONS keys. + Returns: + The log likelihoods of the actions in the batch, given the + observations and the policy. + """ num_state_inputs = 0 for k in batch.keys(): if k.startswith("state_in_"): @@ -66,7 +96,7 @@ def action_prob(self, batch: SampleBatchType) -> np.ndarray: state_keys = ["state_in_{}".format(i) for i in range(num_state_inputs)] log_likelihoods: TensorType = self.policy.compute_log_likelihoods( actions=batch[SampleBatch.ACTIONS], - obs_batch=batch[SampleBatch.CUR_OBS], + obs_batch=batch[SampleBatch.OBS], state_batches=[batch[k] for k in state_keys], prev_action_batch=batch.get(SampleBatch.PREV_ACTIONS), prev_reward_batch=batch.get(SampleBatch.PREV_REWARDS), @@ -76,12 +106,29 @@ def action_prob(self, batch: SampleBatchType) -> np.ndarray: return np.exp(log_likelihoods) @DeveloperAPI - def process(self, batch: SampleBatchType): + def process(self, batch: SampleBatchType) -> None: + """Computes off policy estimates (OPE) on batch and stores results. + + Thus-far collected results can be retrieved then by calling + `self.get_metrics` (which flushes the internal results storage). + + Args: + batch: The batch to process (call `self.estimate()` on) and + store results (OPEs) for. + """ self.new_estimates.append(self.estimate(batch)) @DeveloperAPI - def check_can_estimate_for(self, batch: SampleBatchType): - """Returns whether we can support OPE for this batch.""" + def check_can_estimate_for(self, batch: SampleBatchType) -> None: + """Checks if we support off policy estimation (OPE) on given batch. + + Args: + batch: The batch to check. + + Raises: + ValueError: In case `action_prob` key is not in batch OR batch + is a MultiAgentBatch. + """ if isinstance(batch, MultiAgentBatch): raise ValueError( @@ -98,11 +145,19 @@ def check_can_estimate_for(self, batch: SampleBatchType): @DeveloperAPI def get_metrics(self) -> List[OffPolicyEstimate]: - """Return a list of new episode metric estimates since the last call. + """Returns list of new episode metric estimates since the last call. Returns: - list of OffPolicyEstimate objects. + List of OffPolicyEstimate objects. """ out = self.new_estimates self.new_estimates = [] return out + + @Deprecated(new="OffPolicyEstimator.create_from_io_context", error=False) + def create(self, *args, **kwargs): + return self.create_from_io_context(*args, **kwargs) + + @Deprecated(new="OffPolicyEstimator.action_log_likelihood", error=False) + def action_prob(self, *args, **kwargs): + return self.action_log_likelihood(*args, **kwargs) diff --git a/rllib/offline/output_writer.py b/rllib/offline/output_writer.py index 8d168dfb451c..2389c3d741b6 100644 --- a/rllib/offline/output_writer.py +++ b/rllib/offline/output_writer.py @@ -1,15 +1,14 @@ -from ray.rllib.utils.annotations import override -from ray.rllib.utils.annotations import PublicAPI +from ray.rllib.utils.annotations import override, PublicAPI from ray.rllib.utils.typing import SampleBatchType @PublicAPI class OutputWriter: - """Writer object for saving experiences from policy evaluation.""" + """Writer API for saving experiences from policy evaluation.""" @PublicAPI def write(self, sample_batch: SampleBatchType): - """Save a batch of experiences. + """Saves a batch of experiences. Args: sample_batch: SampleBatch or MultiAgentBatch to save. @@ -22,4 +21,5 @@ class NoopOutput(OutputWriter): @override(OutputWriter) def write(self, sample_batch: SampleBatchType): + # Do nothing. pass diff --git a/rllib/offline/shuffled_input.py b/rllib/offline/shuffled_input.py index 24522c87aa2d..a7c261018594 100644 --- a/rllib/offline/shuffled_input.py +++ b/rllib/offline/shuffled_input.py @@ -18,11 +18,11 @@ class ShuffledInput(InputReader): @DeveloperAPI def __init__(self, child: InputReader, n: int = 0): - """Initialize a MixedInput. + """Initializes a ShuffledInput instance. Args: - child (InputReader): child input reader to shuffle. - n (int): if positive, shuffle input over this many batches. + child: child input reader to shuffle. + n: If positive, shuffle input over this many batches. """ self.n = n self.child = child diff --git a/rllib/offline/wis_estimator.py b/rllib/offline/wis_estimator.py index 74eb342a440e..00bbf3145dd8 100644 --- a/rllib/offline/wis_estimator.py +++ b/rllib/offline/wis_estimator.py @@ -1,54 +1,54 @@ -from ray.rllib.offline.off_policy_estimator import OffPolicyEstimator, \ - OffPolicyEstimate -from ray.rllib.policy import Policy -from ray.rllib.utils.annotations import override -from ray.rllib.utils.typing import SampleBatchType - - -class WeightedImportanceSamplingEstimator(OffPolicyEstimator): - """The weighted step-wise IS estimator. - - Step-wise WIS estimator in https://arxiv.org/pdf/1511.03722.pdf""" - - def __init__(self, policy: Policy, gamma: float): - super().__init__(policy, gamma) - self.filter_values = [] - self.filter_counts = [] - - @override(OffPolicyEstimator) - def estimate(self, batch: SampleBatchType) -> OffPolicyEstimate: - self.check_can_estimate_for(batch) - - rewards, old_prob = batch["rewards"], batch["action_prob"] - new_prob = self.action_prob(batch) - - # calculate importance ratios - p = [] - for t in range(batch.count): - if t == 0: - pt_prev = 1.0 - else: - pt_prev = p[t - 1] - p.append(pt_prev * new_prob[t] / old_prob[t]) - for t, v in enumerate(p): - if t >= len(self.filter_values): - self.filter_values.append(v) - self.filter_counts.append(1.0) - else: - self.filter_values[t] += v - self.filter_counts[t] += 1.0 - - # calculate stepwise weighted IS estimate - V_prev, V_step_WIS = 0.0, 0.0 - for t in range(batch.count): - V_prev += rewards[t] * self.gamma**t - w_t = self.filter_values[t] / self.filter_counts[t] - V_step_WIS += p[t] / w_t * rewards[t] * self.gamma**t - - estimation = OffPolicyEstimate( - "wis", { - "V_prev": V_prev, - "V_step_WIS": V_step_WIS, - "V_gain_est": V_step_WIS / max(1e-8, V_prev), - }) - return estimation +from ray.rllib.offline.off_policy_estimator import OffPolicyEstimator, \ + OffPolicyEstimate +from ray.rllib.policy import Policy +from ray.rllib.utils.annotations import override +from ray.rllib.utils.typing import SampleBatchType + + +class WeightedImportanceSamplingEstimator(OffPolicyEstimator): + """The weighted step-wise IS estimator. + + Step-wise WIS estimator in https://arxiv.org/pdf/1511.03722.pdf""" + + def __init__(self, policy: Policy, gamma: float): + super().__init__(policy, gamma) + self.filter_values = [] + self.filter_counts = [] + + @override(OffPolicyEstimator) + def estimate(self, batch: SampleBatchType) -> OffPolicyEstimate: + self.check_can_estimate_for(batch) + + rewards, old_prob = batch["rewards"], batch["action_prob"] + new_prob = self.action_log_likelihood(batch) + + # calculate importance ratios + p = [] + for t in range(batch.count): + if t == 0: + pt_prev = 1.0 + else: + pt_prev = p[t - 1] + p.append(pt_prev * new_prob[t] / old_prob[t]) + for t, v in enumerate(p): + if t >= len(self.filter_values): + self.filter_values.append(v) + self.filter_counts.append(1.0) + else: + self.filter_values[t] += v + self.filter_counts[t] += 1.0 + + # calculate stepwise weighted IS estimate + V_prev, V_step_WIS = 0.0, 0.0 + for t in range(batch.count): + V_prev += rewards[t] * self.gamma**t + w_t = self.filter_values[t] / self.filter_counts[t] + V_step_WIS += p[t] / w_t * rewards[t] * self.gamma**t + + estimation = OffPolicyEstimate( + "wis", { + "V_prev": V_prev, + "V_step_WIS": V_step_WIS, + "V_gain_est": V_step_WIS / max(1e-8, V_prev), + }) + return estimation From ee57025be6d0f996223ca892b9d61fbf8f63545f Mon Sep 17 00:00:00 2001 From: Edward Oakes Date: Mon, 1 Nov 2021 12:24:02 -0500 Subject: [PATCH 02/15] [serve] Rename BackendConfig -> DeploymentConfig (#19923) --- .../java/io/ray/serve/DeploymentInfo.java | 10 +-- .../java/io/ray/serve/RayServeReplica.java | 14 ++--- .../io/ray/serve/RayServeWrappedReplica.java | 18 +++--- .../main/java/io/ray/serve/ReplicaSet.java | 6 +- .../src/main/java/io/ray/serve/Router.java | 2 +- .../io/ray/serve/poll/LongPollClient.java | 2 +- .../io/ray/serve/util/ServeProtoUtil.java | 39 ++++++------ .../java/io/ray/serve/ProxyActorTest.java | 4 +- .../java/io/ray/serve/RayServeHandleTest.java | 12 ++-- .../io/ray/serve/RayServeReplicaTest.java | 12 ++-- .../java/io/ray/serve/ReplicaSetTest.java | 14 ++--- .../test/java/io/ray/serve/RouterTest.java | 12 ++-- .../io/ray/serve/poll/LongPollClientTest.java | 16 ++--- python/ray/serve/api.py | 60 +++++++++--------- python/ray/serve/backend_state.py | 42 +++++++------ python/ray/serve/common.py | 6 +- python/ray/serve/config.py | 61 +++++++++---------- python/ray/serve/controller.py | 30 ++++----- python/ray/serve/replica.py | 43 ++++++------- python/ray/serve/tests/test_backend_state.py | 9 +-- python/ray/serve/tests/test_config.py | 30 ++++----- src/ray/protobuf/serve.proto | 28 ++++----- 22 files changed, 239 insertions(+), 231 deletions(-) diff --git a/java/serve/src/main/java/io/ray/serve/DeploymentInfo.java b/java/serve/src/main/java/io/ray/serve/DeploymentInfo.java index 0a6117b21133..943be34514c8 100644 --- a/java/serve/src/main/java/io/ray/serve/DeploymentInfo.java +++ b/java/serve/src/main/java/io/ray/serve/DeploymentInfo.java @@ -6,18 +6,18 @@ public class DeploymentInfo implements Serializable { private static final long serialVersionUID = -4198364411759931955L; - private byte[] backendConfig; + private byte[] deploymentConfig; private ReplicaConfig replicaConfig; private byte[] deploymentVersion; - public byte[] getBackendConfig() { - return backendConfig; + public byte[] getDeploymentConfig() { + return deploymentConfig; } - public void setBackendConfig(byte[] backendConfig) { - this.backendConfig = backendConfig; + public void setDeploymentConfig(byte[] deploymentConfig) { + this.deploymentConfig = deploymentConfig; } public ReplicaConfig getReplicaConfig() { diff --git a/java/serve/src/main/java/io/ray/serve/RayServeReplica.java b/java/serve/src/main/java/io/ray/serve/RayServeReplica.java index c8445ca94238..7151f175fd65 100644 --- a/java/serve/src/main/java/io/ray/serve/RayServeReplica.java +++ b/java/serve/src/main/java/io/ray/serve/RayServeReplica.java @@ -9,7 +9,7 @@ import io.ray.runtime.metric.Metrics; import io.ray.runtime.serializer.MessagePackSerializer; import io.ray.serve.api.Serve; -import io.ray.serve.generated.BackendConfig; +import io.ray.serve.generated.DeploymentConfig; import io.ray.serve.generated.DeploymentVersion; import io.ray.serve.generated.RequestWrapper; import io.ray.serve.poll.KeyListener; @@ -34,7 +34,7 @@ public class RayServeReplica { private String replicaTag; - private BackendConfig config; + private DeploymentConfig config; private AtomicInteger numOngoingRequests = new AtomicInteger(); @@ -58,19 +58,19 @@ public class RayServeReplica { public RayServeReplica( Object callable, - BackendConfig backendConfig, + DeploymentConfig deploymentConfig, DeploymentVersion version, BaseActorHandle actorHandle) { this.backendTag = Serve.getReplicaContext().getBackendTag(); this.replicaTag = Serve.getReplicaContext().getReplicaTag(); this.callable = callable; - this.config = backendConfig; + this.config = deploymentConfig; this.version = version; Map keyListeners = new HashMap<>(); keyListeners.put( new KeyType(LongPollNamespace.BACKEND_CONFIGS, backendTag), - newConfig -> updateBackendConfigs(newConfig)); + newConfig -> updateDeploymentConfigs(newConfig)); this.longPollClient = new LongPollClient(actorHandle, keyListeners); this.longPollClient.start(); registerMetrics(); @@ -319,8 +319,8 @@ public DeploymentVersion reconfigure(Object userConfig) { * * @param newConfig the new configuration of backend */ - private void updateBackendConfigs(Object newConfig) { - config = (BackendConfig) newConfig; + private void updateDeploymentConfigs(Object newConfig) { + config = (DeploymentConfig) newConfig; } public DeploymentVersion getVersion() { diff --git a/java/serve/src/main/java/io/ray/serve/RayServeWrappedReplica.java b/java/serve/src/main/java/io/ray/serve/RayServeWrappedReplica.java index e3228c9c0aaf..545feb81da94 100644 --- a/java/serve/src/main/java/io/ray/serve/RayServeWrappedReplica.java +++ b/java/serve/src/main/java/io/ray/serve/RayServeWrappedReplica.java @@ -6,7 +6,7 @@ import io.ray.api.Ray; import io.ray.runtime.serializer.MessagePackSerializer; import io.ray.serve.api.Serve; -import io.ray.serve.generated.BackendConfig; +import io.ray.serve.generated.DeploymentConfig; import io.ray.serve.generated.DeploymentVersion; import io.ray.serve.generated.RequestMetadata; import io.ray.serve.util.ReflectUtil; @@ -27,17 +27,17 @@ public RayServeWrappedReplica( String replicaTag, String backendDef, byte[] initArgsbytes, - byte[] backendConfigBytes, + byte[] deploymentConfigBytes, byte[] deploymentVersionBytes, String controllerName) throws ClassNotFoundException, NoSuchMethodException, InstantiationException, IllegalAccessException, IllegalArgumentException, InvocationTargetException, IOException { - // Parse BackendConfig. - BackendConfig backendConfig = ServeProtoUtil.parseBackendConfig(backendConfigBytes); + // Parse DeploymentConfig. + DeploymentConfig deploymentConfig = ServeProtoUtil.parseDeploymentConfig(deploymentConfigBytes); // Parse init args. - Object[] initArgs = parseInitArgs(initArgsbytes, backendConfig); + Object[] initArgs = parseInitArgs(initArgsbytes, deploymentConfig); // Instantiate the object defined by backendDef. Class backendClass = Class.forName(backendDef); @@ -57,7 +57,7 @@ public RayServeWrappedReplica( backend = new RayServeReplica( callable, - backendConfig, + deploymentConfig, ServeProtoUtil.parseDeploymentVersion(deploymentVersionBytes), optional.get()); } @@ -71,19 +71,19 @@ public RayServeWrappedReplica( replicaTag, deploymentInfo.getReplicaConfig().getBackendDef(), deploymentInfo.getReplicaConfig().getInitArgs(), - deploymentInfo.getBackendConfig(), + deploymentInfo.getDeploymentConfig(), deploymentInfo.getDeploymentVersion(), controllerName); } - private Object[] parseInitArgs(byte[] initArgsbytes, BackendConfig backendConfig) + private Object[] parseInitArgs(byte[] initArgsbytes, DeploymentConfig deploymentConfig) throws IOException { if (initArgsbytes == null || initArgsbytes.length == 0) { return new Object[0]; } - if (!backendConfig.getIsCrossLanguage()) { + if (!deploymentConfig.getIsCrossLanguage()) { // If the construction request is from Java API, deserialize initArgsbytes to Object[] // directly. return MessagePackSerializer.decode(initArgsbytes, Object[].class); diff --git a/java/serve/src/main/java/io/ray/serve/ReplicaSet.java b/java/serve/src/main/java/io/ray/serve/ReplicaSet.java index 1c7e757bba44..0b5646f4bd5c 100644 --- a/java/serve/src/main/java/io/ray/serve/ReplicaSet.java +++ b/java/serve/src/main/java/io/ray/serve/ReplicaSet.java @@ -9,7 +9,7 @@ import io.ray.runtime.metric.Metrics; import io.ray.runtime.metric.TagKey; import io.ray.serve.generated.ActorSet; -import io.ray.serve.generated.BackendConfig; +import io.ray.serve.generated.DeploymentConfig; import io.ray.serve.util.CollectionUtil; import java.util.ArrayList; import java.util.HashSet; @@ -48,8 +48,8 @@ public ReplicaSet(String backendTag) { .register()); } - public void setMaxConcurrentQueries(Object backendConfig) { - int newValue = ((BackendConfig) backendConfig).getMaxConcurrentQueries(); + public void setMaxConcurrentQueries(Object deploymentConfig) { + int newValue = ((DeploymentConfig) deploymentConfig).getMaxConcurrentQueries(); if (newValue != this.maxConcurrentQueries) { this.maxConcurrentQueries = newValue; LOGGER.info("ReplicaSet: changing max_concurrent_queries to {}", newValue); diff --git a/java/serve/src/main/java/io/ray/serve/Router.java b/java/serve/src/main/java/io/ray/serve/Router.java index 5ef339d77767..0b744c835a09 100644 --- a/java/serve/src/main/java/io/ray/serve/Router.java +++ b/java/serve/src/main/java/io/ray/serve/Router.java @@ -38,7 +38,7 @@ public Router(BaseActorHandle controllerHandle, String backendTag) { Map keyListeners = new HashMap<>(); keyListeners.put( new KeyType(LongPollNamespace.BACKEND_CONFIGS, backendTag), - backendConfig -> replicaSet.setMaxConcurrentQueries(backendConfig)); // cross language + deploymentConfig -> replicaSet.setMaxConcurrentQueries(deploymentConfig)); // cross language keyListeners.put( new KeyType(LongPollNamespace.REPLICA_HANDLES, backendTag), workerReplicas -> replicaSet.updateWorkerReplicas(workerReplicas)); // cross language diff --git a/java/serve/src/main/java/io/ray/serve/poll/LongPollClient.java b/java/serve/src/main/java/io/ray/serve/poll/LongPollClient.java index 308391254e10..54d107b717aa 100644 --- a/java/serve/src/main/java/io/ray/serve/poll/LongPollClient.java +++ b/java/serve/src/main/java/io/ray/serve/poll/LongPollClient.java @@ -47,7 +47,7 @@ public class LongPollClient { static { DESERIALIZERS.put( - LongPollNamespace.BACKEND_CONFIGS, body -> ServeProtoUtil.parseBackendConfig(body)); + LongPollNamespace.BACKEND_CONFIGS, body -> ServeProtoUtil.parseDeploymentConfig(body)); DESERIALIZERS.put( LongPollNamespace.REPLICA_HANDLES, body -> ServeProtoUtil.parseEndpointSet(body)); DESERIALIZERS.put( diff --git a/java/serve/src/main/java/io/ray/serve/util/ServeProtoUtil.java b/java/serve/src/main/java/io/ray/serve/util/ServeProtoUtil.java index eeee061a9fb8..edcc8399e101 100644 --- a/java/serve/src/main/java/io/ray/serve/util/ServeProtoUtil.java +++ b/java/serve/src/main/java/io/ray/serve/util/ServeProtoUtil.java @@ -7,8 +7,8 @@ import io.ray.runtime.serializer.MessagePackSerializer; import io.ray.serve.Constants; import io.ray.serve.RayServeException; -import io.ray.serve.generated.BackendConfig; -import io.ray.serve.generated.BackendLanguage; +import io.ray.serve.generated.DeploymentConfig; +import io.ray.serve.generated.DeploymentLanguage; import io.ray.serve.generated.DeploymentVersion; import io.ray.serve.generated.EndpointInfo; import io.ray.serve.generated.EndpointSet; @@ -25,23 +25,23 @@ public class ServeProtoUtil { private static final Gson GSON = new Gson(); - public static BackendConfig parseBackendConfig(byte[] backendConfigBytes) { + public static DeploymentConfig parseDeploymentConfig(byte[] deploymentConfigBytes) { - // Get a builder from BackendConfig(bytes) or create a new one. - BackendConfig.Builder builder = null; - if (backendConfigBytes == null) { - builder = BackendConfig.newBuilder(); + // Get a builder from DeploymentConfig(bytes) or create a new one. + DeploymentConfig.Builder builder = null; + if (deploymentConfigBytes == null) { + builder = DeploymentConfig.newBuilder(); } else { - BackendConfig backendConfig = null; + DeploymentConfig deploymentConfig = null; try { - backendConfig = BackendConfig.parseFrom(backendConfigBytes); + deploymentConfig = DeploymentConfig.parseFrom(deploymentConfigBytes); } catch (InvalidProtocolBufferException e) { - throw new RayServeException("Failed to parse BackendConfig from protobuf bytes.", e); + throw new RayServeException("Failed to parse DeploymentConfig from protobuf bytes.", e); } - if (backendConfig == null) { - builder = BackendConfig.newBuilder(); + if (deploymentConfig == null) { + builder = DeploymentConfig.newBuilder(); } else { - builder = BackendConfig.newBuilder(backendConfig); + builder = DeploymentConfig.newBuilder(deploymentConfig); } } @@ -64,22 +64,23 @@ public static BackendConfig parseBackendConfig(byte[] backendConfigBytes) { builder.setGracefulShutdownTimeoutS(20); } - if (builder.getBackendLanguage() == BackendLanguage.UNRECOGNIZED) { + if (builder.getDeploymentLanguage() == DeploymentLanguage.UNRECOGNIZED) { throw new RayServeException( LogUtil.format( "Unrecognized backend language {}. Backend language must be in {}.", - builder.getBackendLanguageValue(), - Lists.newArrayList(BackendLanguage.values()))); + builder.getDeploymentLanguageValue(), + Lists.newArrayList(DeploymentLanguage.values()))); } return builder.build(); } - public static Object parseUserConfig(BackendConfig backendConfig) { - if (backendConfig.getUserConfig() == null || backendConfig.getUserConfig().size() == 0) { + public static Object parseUserConfig(DeploymentConfig deploymentConfig) { + if (deploymentConfig.getUserConfig() == null || deploymentConfig.getUserConfig().size() == 0) { return null; } - return MessagePackSerializer.decode(backendConfig.getUserConfig().toByteArray(), Object.class); + return MessagePackSerializer.decode( + deploymentConfig.getUserConfig().toByteArray(), Object.class); } public static RequestMetadata parseRequestMetadata(byte[] requestMetadataBytes) diff --git a/java/serve/src/test/java/io/ray/serve/ProxyActorTest.java b/java/serve/src/test/java/io/ray/serve/ProxyActorTest.java index 632e2805b962..69f7e0ed987d 100644 --- a/java/serve/src/test/java/io/ray/serve/ProxyActorTest.java +++ b/java/serve/src/test/java/io/ray/serve/ProxyActorTest.java @@ -5,7 +5,7 @@ import io.ray.runtime.serializer.MessagePackSerializer; import io.ray.serve.api.Serve; import io.ray.serve.generated.ActorSet; -import io.ray.serve.generated.BackendConfig; +import io.ray.serve.generated.DeploymentConfig; import io.ray.serve.generated.DeploymentVersion; import io.ray.serve.generated.EndpointInfo; import io.ray.serve.util.CommonUtil; @@ -51,7 +51,7 @@ public void test() throws IOException { // Replica DeploymentInfo deploymentInfo = new DeploymentInfo(); - deploymentInfo.setBackendConfig(BackendConfig.newBuilder().build().toByteArray()); + deploymentInfo.setDeploymentConfig(DeploymentConfig.newBuilder().build().toByteArray()); deploymentInfo.setDeploymentVersion( DeploymentVersion.newBuilder().setCodeVersion(version).build().toByteArray()); deploymentInfo.setReplicaConfig( diff --git a/java/serve/src/test/java/io/ray/serve/RayServeHandleTest.java b/java/serve/src/test/java/io/ray/serve/RayServeHandleTest.java index ae2a6e22dddd..243ea1d5657b 100644 --- a/java/serve/src/test/java/io/ray/serve/RayServeHandleTest.java +++ b/java/serve/src/test/java/io/ray/serve/RayServeHandleTest.java @@ -5,8 +5,8 @@ import io.ray.api.Ray; import io.ray.runtime.serializer.MessagePackSerializer; import io.ray.serve.generated.ActorSet; -import io.ray.serve.generated.BackendConfig; -import io.ray.serve.generated.BackendLanguage; +import io.ray.serve.generated.DeploymentConfig; +import io.ray.serve.generated.DeploymentLanguage; import io.ray.serve.generated.DeploymentVersion; import java.util.HashMap; import org.testng.Assert; @@ -31,15 +31,15 @@ public void test() { Ray.actor(DummyServeController::new).setName(controllerName).remote(); // Replica - BackendConfig.Builder backendConfigBuilder = BackendConfig.newBuilder(); - backendConfigBuilder.setBackendLanguage(BackendLanguage.JAVA); - byte[] backendConfigBytes = backendConfigBuilder.build().toByteArray(); + DeploymentConfig.Builder deploymentConfigBuilder = DeploymentConfig.newBuilder(); + deploymentConfigBuilder.setDeploymentLanguage(DeploymentLanguage.JAVA); + byte[] deploymentConfigBytes = deploymentConfigBuilder.build().toByteArray(); Object[] initArgs = new Object[] {backendTag, replicaTag, controllerName, new Object()}; byte[] initArgsBytes = MessagePackSerializer.encode(initArgs).getLeft(); DeploymentInfo deploymentInfo = new DeploymentInfo(); - deploymentInfo.setBackendConfig(backendConfigBytes); + deploymentInfo.setDeploymentConfig(deploymentConfigBytes); deploymentInfo.setDeploymentVersion( DeploymentVersion.newBuilder().setCodeVersion(version).build().toByteArray()); deploymentInfo.setReplicaConfig( diff --git a/java/serve/src/test/java/io/ray/serve/RayServeReplicaTest.java b/java/serve/src/test/java/io/ray/serve/RayServeReplicaTest.java index 7c0e61b95221..7e9d5d5e5577 100644 --- a/java/serve/src/test/java/io/ray/serve/RayServeReplicaTest.java +++ b/java/serve/src/test/java/io/ray/serve/RayServeReplicaTest.java @@ -4,8 +4,8 @@ import io.ray.api.ObjectRef; import io.ray.api.Ray; import io.ray.runtime.serializer.MessagePackSerializer; -import io.ray.serve.generated.BackendConfig; -import io.ray.serve.generated.BackendLanguage; +import io.ray.serve.generated.DeploymentConfig; +import io.ray.serve.generated.DeploymentLanguage; import io.ray.serve.generated.DeploymentVersion; import io.ray.serve.generated.RequestMetadata; import io.ray.serve.generated.RequestWrapper; @@ -32,14 +32,14 @@ public void test() throws IOException { ActorHandle controllerHandle = Ray.actor(DummyServeController::new).setName(controllerName).remote(); - BackendConfig.Builder backendConfigBuilder = BackendConfig.newBuilder(); - backendConfigBuilder.setBackendLanguage(BackendLanguage.JAVA); - byte[] backendConfigBytes = backendConfigBuilder.build().toByteArray(); + DeploymentConfig.Builder deploymentConfigBuilder = DeploymentConfig.newBuilder(); + deploymentConfigBuilder.setDeploymentLanguage(DeploymentLanguage.JAVA); + byte[] deploymentConfigBytes = deploymentConfigBuilder.build().toByteArray(); Object[] initArgs = new Object[] {backendTag, replicaTag, controllerName, new Object()}; byte[] initArgsBytes = MessagePackSerializer.encode(initArgs).getLeft(); DeploymentInfo deploymentInfo = new DeploymentInfo(); - deploymentInfo.setBackendConfig(backendConfigBytes); + deploymentInfo.setDeploymentConfig(deploymentConfigBytes); deploymentInfo.setDeploymentVersion( DeploymentVersion.newBuilder().setCodeVersion(version).build().toByteArray()); deploymentInfo.setReplicaConfig( diff --git a/java/serve/src/test/java/io/ray/serve/ReplicaSetTest.java b/java/serve/src/test/java/io/ray/serve/ReplicaSetTest.java index 972d357dd114..d70edcec3082 100644 --- a/java/serve/src/test/java/io/ray/serve/ReplicaSetTest.java +++ b/java/serve/src/test/java/io/ray/serve/ReplicaSetTest.java @@ -5,8 +5,8 @@ import io.ray.api.Ray; import io.ray.runtime.serializer.MessagePackSerializer; import io.ray.serve.generated.ActorSet; -import io.ray.serve.generated.BackendConfig; -import io.ray.serve.generated.BackendLanguage; +import io.ray.serve.generated.DeploymentConfig; +import io.ray.serve.generated.DeploymentLanguage; import io.ray.serve.generated.DeploymentVersion; import io.ray.serve.generated.RequestMetadata; import java.util.HashMap; @@ -23,7 +23,7 @@ public class ReplicaSetTest { @Test public void setMaxConcurrentQueriesTest() { ReplicaSet replicaSet = new ReplicaSet(backendTag); - BackendConfig.Builder builder = BackendConfig.newBuilder(); + DeploymentConfig.Builder builder = DeploymentConfig.newBuilder(); builder.setMaxConcurrentQueries(200); replicaSet.setMaxConcurrentQueries(builder.build()); @@ -58,15 +58,15 @@ public void assignReplicaTest() { Ray.actor(DummyServeController::new).setName(controllerName).remote(); // Replica - BackendConfig.Builder backendConfigBuilder = BackendConfig.newBuilder(); - backendConfigBuilder.setBackendLanguage(BackendLanguage.JAVA); - byte[] backendConfigBytes = backendConfigBuilder.build().toByteArray(); + DeploymentConfig.Builder deploymentConfigBuilder = DeploymentConfig.newBuilder(); + deploymentConfigBuilder.setDeploymentLanguage(DeploymentLanguage.JAVA); + byte[] deploymentConfigBytes = deploymentConfigBuilder.build().toByteArray(); Object[] initArgs = new Object[] {backendTag, replicaTag, controllerName, new Object()}; byte[] initArgsBytes = MessagePackSerializer.encode(initArgs).getLeft(); DeploymentInfo deploymentInfo = new DeploymentInfo(); - deploymentInfo.setBackendConfig(backendConfigBytes); + deploymentInfo.setDeploymentConfig(deploymentConfigBytes); deploymentInfo.setDeploymentVersion( DeploymentVersion.newBuilder().setCodeVersion(version).build().toByteArray()); deploymentInfo.setReplicaConfig( diff --git a/java/serve/src/test/java/io/ray/serve/RouterTest.java b/java/serve/src/test/java/io/ray/serve/RouterTest.java index f52f8822eaba..03ee6da4a6ef 100644 --- a/java/serve/src/test/java/io/ray/serve/RouterTest.java +++ b/java/serve/src/test/java/io/ray/serve/RouterTest.java @@ -5,8 +5,8 @@ import io.ray.api.Ray; import io.ray.runtime.serializer.MessagePackSerializer; import io.ray.serve.generated.ActorSet; -import io.ray.serve.generated.BackendConfig; -import io.ray.serve.generated.BackendLanguage; +import io.ray.serve.generated.DeploymentConfig; +import io.ray.serve.generated.DeploymentLanguage; import io.ray.serve.generated.DeploymentVersion; import io.ray.serve.generated.RequestMetadata; import java.util.HashMap; @@ -33,15 +33,15 @@ public void test() { Ray.actor(DummyServeController::new).setName(controllerName).remote(); // Replica - BackendConfig.Builder backendConfigBuilder = BackendConfig.newBuilder(); - backendConfigBuilder.setBackendLanguage(BackendLanguage.JAVA); - byte[] backendConfigBytes = backendConfigBuilder.build().toByteArray(); + DeploymentConfig.Builder deploymentConfigBuilder = DeploymentConfig.newBuilder(); + deploymentConfigBuilder.setDeploymentLanguage(DeploymentLanguage.JAVA); + byte[] deploymentConfigBytes = deploymentConfigBuilder.build().toByteArray(); Object[] initArgs = new Object[] {backendTag, replicaTag, controllerName, new Object()}; byte[] initArgsBytes = MessagePackSerializer.encode(initArgs).getLeft(); DeploymentInfo deploymentInfo = new DeploymentInfo(); - deploymentInfo.setBackendConfig(backendConfigBytes); + deploymentInfo.setDeploymentConfig(deploymentConfigBytes); deploymentInfo.setDeploymentVersion( DeploymentVersion.newBuilder().setCodeVersion(version).build().toByteArray()); deploymentInfo.setReplicaConfig( diff --git a/java/serve/src/test/java/io/ray/serve/poll/LongPollClientTest.java b/java/serve/src/test/java/io/ray/serve/poll/LongPollClientTest.java index 7ee254806fad..6bda0c277907 100644 --- a/java/serve/src/test/java/io/ray/serve/poll/LongPollClientTest.java +++ b/java/serve/src/test/java/io/ray/serve/poll/LongPollClientTest.java @@ -1,7 +1,7 @@ package io.ray.serve.poll; import com.google.protobuf.ByteString; -import io.ray.serve.generated.BackendConfig; +import io.ray.serve.generated.DeploymentConfig; import io.ray.serve.generated.UpdatedObject; import java.util.HashMap; import java.util.Map; @@ -19,18 +19,18 @@ public void test() throws Throwable { KeyType keyType = new KeyType(LongPollNamespace.BACKEND_CONFIGS, "backendTag"); Map keyListeners = new HashMap<>(); keyListeners.put( - keyType, (object) -> a[0] = String.valueOf(((BackendConfig) object).getNumReplicas())); + keyType, (object) -> a[0] = String.valueOf(((DeploymentConfig) object).getNumReplicas())); // Initialize LongPollClient. LongPollClient longPollClient = new LongPollClient(null, keyListeners); // Construct updated object. - BackendConfig.Builder backendConfig = BackendConfig.newBuilder(); - backendConfig.setNumReplicas(20); + DeploymentConfig.Builder deploymentConfig = DeploymentConfig.newBuilder(); + deploymentConfig.setNumReplicas(20); int snapshotId = 10; UpdatedObject.Builder updatedObject = UpdatedObject.newBuilder(); updatedObject.setSnapshotId(snapshotId); - updatedObject.setObjectSnapshot(ByteString.copyFrom(backendConfig.build().toByteArray())); + updatedObject.setObjectSnapshot(ByteString.copyFrom(deploymentConfig.build().toByteArray())); // Process update. Map updates = new HashMap<>(); @@ -40,8 +40,8 @@ public void test() throws Throwable { // Validation. Assert.assertEquals(longPollClient.getSnapshotIds().get(keyType).intValue(), snapshotId); Assert.assertEquals( - ((BackendConfig) longPollClient.getObjectSnapshots().get(keyType)).getNumReplicas(), - backendConfig.getNumReplicas()); - Assert.assertEquals(a[0], String.valueOf(backendConfig.getNumReplicas())); + ((DeploymentConfig) longPollClient.getObjectSnapshots().get(keyType)).getNumReplicas(), + deploymentConfig.getNumReplicas()); + Assert.assertEquals(a[0], String.valueOf(deploymentConfig.getNumReplicas())); } } diff --git a/python/ray/serve/api.py b/python/ray/serve/api.py index 06bacdf2b2d1..952e9aee57fa 100644 --- a/python/ray/serve/api.py +++ b/python/ray/serve/api.py @@ -17,7 +17,7 @@ from ray.actor import ActorHandle from ray.serve.common import BackendInfo, GoalId, ReplicaTag -from ray.serve.config import (AutoscalingConfig, BackendConfig, HTTPOptions, +from ray.serve.config import (AutoscalingConfig, DeploymentConfig, HTTPOptions, ReplicaConfig) from ray.serve.constants import (DEFAULT_CHECKPOINT_PATH, HTTP_PROXY_TIMEOUT, SERVE_CONTROLLER_NAME, MAX_CACHED_HANDLES, @@ -187,18 +187,19 @@ def _wait_for_goal(self, return False @_ensure_connected - def deploy(self, - name: str, - backend_def: Union[Callable, Type[Callable], str], - init_args: Tuple[Any], - init_kwargs: Dict[Any, Any], - ray_actor_options: Optional[Dict] = None, - config: Optional[Union[BackendConfig, Dict[str, Any]]] = None, - version: Optional[str] = None, - prev_version: Optional[str] = None, - route_prefix: Optional[str] = None, - url: str = "", - _blocking: Optional[bool] = True) -> Optional[GoalId]: + def deploy( + self, + name: str, + deployment_def: Union[Callable, Type[Callable], str], + init_args: Tuple[Any], + init_kwargs: Dict[Any, Any], + ray_actor_options: Optional[Dict] = None, + config: Optional[Union[DeploymentConfig, Dict[str, Any]]] = None, + version: Optional[str] = None, + prev_version: Optional[str] = None, + route_prefix: Optional[str] = None, + url: str = "", + _blocking: Optional[bool] = True) -> Optional[GoalId]: if config is None: config = {} if ray_actor_options is None: @@ -212,23 +213,25 @@ def deploy(self, ray_actor_options["runtime_env"] = curr_job_env replica_config = ReplicaConfig( - backend_def, + deployment_def, init_args=init_args, init_kwargs=init_kwargs, ray_actor_options=ray_actor_options) if isinstance(config, dict): - backend_config = BackendConfig.parse_obj(config) - elif isinstance(config, BackendConfig): - backend_config = config + deployment_config = DeploymentConfig.parse_obj(config) + elif isinstance(config, DeploymentConfig): + deployment_config = config else: - raise TypeError("config must be a BackendConfig or a dictionary.") + raise TypeError( + "config must be a DeploymentConfig or a dictionary.") goal_id, updating = ray.get( - self._controller.deploy.remote( - name, backend_config.to_proto_bytes(), replica_config, version, - prev_version, route_prefix, - ray.get_runtime_context().job_id)) + self._controller.deploy.remote(name, + deployment_config.to_proto_bytes(), + replica_config, version, + prev_version, route_prefix, + ray.get_runtime_context().job_id)) tag = f"component=serve deployment={name}" @@ -626,7 +629,7 @@ class Deployment: def __init__(self, func_or_class: Callable, name: str, - config: BackendConfig, + config: DeploymentConfig, version: Optional[str] = None, prev_version: Optional[str] = None, init_args: Optional[Tuple[Any]] = None, @@ -1021,7 +1024,7 @@ class MyDeployment: raise ValueError("Manually setting num_replicas is not allowed when " "_autoscaling_config is provided.") - config = BackendConfig() + config = DeploymentConfig() if num_replicas is not None: config.num_replicas = num_replicas @@ -1085,9 +1088,10 @@ def get_deployment(name: str) -> Deployment: raise KeyError(f"Deployment {name} was not found. " "Did you call Deployment.deploy()?") return Deployment( - cloudpickle.loads(backend_info.replica_config.serialized_backend_def), + cloudpickle.loads( + backend_info.replica_config.serialized_deployment_def), name, - backend_info.backend_config, + backend_info.deployment_config, version=backend_info.version, init_args=backend_info.replica_config.init_args, init_kwargs=backend_info.replica_config.init_kwargs, @@ -1109,9 +1113,9 @@ def list_deployments() -> Dict[str, Deployment]: for name, (backend_info, route_prefix) in infos.items(): deployments[name] = Deployment( cloudpickle.loads( - backend_info.replica_config.serialized_backend_def), + backend_info.replica_config.serialized_deployment_def), name, - backend_info.backend_config, + backend_info.deployment_config, version=backend_info.version, init_args=backend_info.replica_config.init_args, init_kwargs=backend_info.replica_config.init_kwargs, diff --git a/python/ray/serve/backend_state.py b/python/ray/serve/backend_state.py index 8f23bc3fba75..7b00235f34af 100644 --- a/python/ray/serve/backend_state.py +++ b/python/ray/serve/backend_state.py @@ -11,7 +11,7 @@ from ray.serve.async_goal_manager import AsyncGoalManager from ray.serve.common import (BackendInfo, BackendTag, Duration, GoalId, ReplicaTag, ReplicaName, RunningReplicaInfo) -from ray.serve.config import BackendConfig +from ray.serve.config import DeploymentConfig from ray.serve.constants import ( CONTROLLER_STARTUP_GRACE_PERIOD_S, SERVE_CONTROLLER_NAME, SERVE_PROXY_NAME, MAX_DEPLOYMENT_CONSTRUCTOR_RETRY_COUNT, MAX_NUM_DELETED_DEPLOYMENTS) @@ -156,9 +156,9 @@ def start(self, backend_info: BackendInfo, version: DeploymentVersion): """ self._actor_resources = backend_info.replica_config.resource_dict self._max_concurrent_queries = ( - backend_info.backend_config.max_concurrent_queries) + backend_info.deployment_config.max_concurrent_queries) self._graceful_shutdown_timeout_s = ( - backend_info.backend_config.graceful_shutdown_timeout_s) + backend_info.deployment_config.graceful_shutdown_timeout_s) if USE_PLACEMENT_GROUP: self._placement_group = self.create_placement_group( self._placement_group_name, self._actor_resources) @@ -177,11 +177,11 @@ def start(self, backend_info: BackendInfo, version: DeploymentVersion): self.backend_tag, self.replica_tag, backend_info.replica_config.init_args, backend_info.replica_config.init_kwargs, - backend_info.backend_config.to_proto_bytes(), version, + backend_info.deployment_config.to_proto_bytes(), version, self._controller_name, self._detached) self._ready_obj_ref = self._actor_handle.reconfigure.remote( - backend_info.backend_config.user_config) + backend_info.deployment_config.user_config) def update_user_config(self, user_config: Any): """ @@ -242,11 +242,11 @@ def check_ready( return ReplicaStartupStatus.PENDING, None elif len(ready) > 0: try: - backend_config, version = ray.get(ready)[0] + deployment_config, version = ray.get(ready)[0] self._max_concurrent_queries = ( - backend_config.max_concurrent_queries) + deployment_config.max_concurrent_queries) self._graceful_shutdown_timeout_s = ( - backend_config.graceful_shutdown_timeout_s) + deployment_config.graceful_shutdown_timeout_s) except Exception: return ReplicaStartupStatus.FAILED, None @@ -726,11 +726,11 @@ def _set_backend_goal(self, backend_info: Optional[BackendInfo]) -> None: if backend_info is not None: self._target_info = backend_info - self._target_replicas = backend_info.backend_config.num_replicas + self._target_replicas = backend_info.deployment_config.num_replicas self._target_version = DeploymentVersion( backend_info.version, - user_config=backend_info.backend_config.user_config) + user_config=backend_info.deployment_config.user_config) else: self._target_replicas = 0 @@ -746,7 +746,7 @@ def deploy(self, backend_info: BackendInfo) -> Tuple[Optional[GoalId], bool]: """Deploy the backend. - If the backend already exists with the same version and BackendConfig, + If the backend already exists with the same version and config, this is a no-op and returns the GoalId corresponding to the existing update if there is one. @@ -760,7 +760,8 @@ def deploy(self, # Redeploying should not reset the deployment's start time. backend_info.start_time_ms = existing_info.start_time_ms - if (existing_info.backend_config == backend_info.backend_config + if (existing_info.deployment_config == + backend_info.deployment_config and backend_info.version is not None and existing_info.version == backend_info.version): return self._curr_goal, False @@ -1291,19 +1292,20 @@ def get_running_replica_infos( return replicas - def get_backend_configs(self, - filter_tag: Optional[BackendTag] = None, - include_deleted: Optional[bool] = False - ) -> Dict[BackendTag, BackendConfig]: - configs: Dict[BackendTag, BackendConfig] = {} + def get_deployment_configs(self, + filter_tag: Optional[BackendTag] = None, + include_deleted: Optional[bool] = False + ) -> Dict[BackendTag, DeploymentConfig]: + configs: Dict[BackendTag, DeploymentConfig] = {} for backend_tag, backend_state in self._backend_states.items(): if filter_tag is None or backend_tag == filter_tag: - configs[backend_tag] = backend_state.target_info.backend_config + configs[ + backend_tag] = backend_state.target_info.deployment_config if include_deleted: for backend_tag, info in self._deleted_backend_metadata.items(): if filter_tag is None or backend_tag == filter_tag: - configs[backend_tag] = info.backend_config + configs[backend_tag] = info.deployment_config return configs @@ -1322,7 +1324,7 @@ def deploy_backend(self, backend_tag: BackendTag, backend_info: BackendInfo ) -> Tuple[Optional[GoalId], bool]: """Deploy the backend. - If the backend already exists with the same version and BackendConfig, + If the backend already exists with the same version and config, this is a no-op and returns the GoalId corresponding to the existing update if there is one. diff --git a/python/ray/serve/common.py b/python/ray/serve/common.py index 3c236a668d26..eb0a5dead6ea 100644 --- a/python/ray/serve/common.py +++ b/python/ray/serve/common.py @@ -5,7 +5,7 @@ from uuid import UUID from ray.actor import ActorClass, ActorHandle -from ray.serve.config import BackendConfig, ReplicaConfig +from ray.serve.config import DeploymentConfig, ReplicaConfig from ray.serve.autoscaling_policy import AutoscalingPolicy BackendTag = str @@ -23,7 +23,7 @@ class EndpointInfo: class BackendInfo: def __init__(self, - backend_config: BackendConfig, + deployment_config: DeploymentConfig, replica_config: ReplicaConfig, start_time_ms: int, actor_def: Optional[ActorClass] = None, @@ -31,7 +31,7 @@ def __init__(self, deployer_job_id: "Optional[ray._raylet.JobID]" = None, end_time_ms: Optional[int] = None, autoscaling_policy: Optional[AutoscalingPolicy] = None): - self.backend_config = backend_config + self.deployment_config = deployment_config self.replica_config = replica_config # The time when .deploy() was first called for this deployment. self.start_time_ms = start_time_ms diff --git a/python/ray/serve/config.py b/python/ray/serve/config.py index 7f010d3ff37e..60a1e1344dd8 100644 --- a/python/ray/serve/config.py +++ b/python/ray/serve/config.py @@ -7,10 +7,10 @@ from google.protobuf.json_format import MessageToDict from pydantic import BaseModel, NonNegativeFloat, PositiveInt, validator from ray.serve.constants import DEFAULT_HTTP_HOST, DEFAULT_HTTP_PORT -from ray.serve.generated.serve_pb2 import (BackendConfig as BackendConfigProto, - AutoscalingConfig as - AutoscalingConfigProto) -from ray.serve.generated.serve_pb2 import BackendLanguage +from ray.serve.generated.serve_pb2 import ( + DeploymentConfig as DeploymentConfigProto, AutoscalingConfig as + AutoscalingConfigProto) +from ray.serve.generated.serve_pb2 import DeploymentLanguage from ray import cloudpickle as cloudpickle @@ -56,21 +56,21 @@ class AutoscalingConfig(BaseModel): # TODO(architkulkarni): Add pydantic validation. E.g. max_replicas>=min -class BackendConfig(BaseModel): - """Configuration options for a backend, to be set by the user. +class DeploymentConfig(BaseModel): + """Configuration options for a deployment, to be set by the user. Args: num_replicas (Optional[int]): The number of processes to start up that - will handle requests to this backend. Defaults to 1. + will handle requests to this deployment. Defaults to 1. max_concurrent_queries (Optional[int]): The maximum number of queries - that will be sent to a replica of this backend without receiving a - response. Defaults to 100. + that will be sent to a replica of this deployment without receiving + a response. Defaults to 100. user_config (Optional[Any]): Arguments to pass to the reconfigure - method of the backend. The reconfigure method is called if + method of the deployment. The reconfigure method is called if user_config is not None. graceful_shutdown_wait_loop_s (Optional[float]): Duration - that backend workers will wait until there is no more work to be - done before shutting down. Defaults to 2s. + that deployment replicas will wait until there is no more work to + be done before shutting down. Defaults to 2s. graceful_shutdown_timeout_s (Optional[float]): Controller waits for this duration to forcefully kill the replica for shutdown. Defaults to 20s. @@ -107,15 +107,15 @@ def to_proto_bytes(self): if data.get("autoscaling_config"): data["autoscaling_config"] = AutoscalingConfigProto( **data["autoscaling_config"]) - return BackendConfigProto( + return DeploymentConfigProto( is_cross_language=False, - backend_language=BackendLanguage.PYTHON, + deployment_language=DeploymentLanguage.PYTHON, **data, ).SerializeToString() @classmethod def from_proto_bytes(cls, proto_bytes: bytes): - proto = BackendConfigProto.FromString(proto_bytes) + proto = DeploymentConfigProto.FromString(proto_bytes) data = MessageToDict( proto, including_default_value_fields=True, @@ -131,37 +131,36 @@ def from_proto_bytes(cls, proto_bytes: bytes): # Delete fields which are only used in protobuf, not in Python. del data["is_cross_language"] - del data["backend_language"] + del data["deployment_language"] return cls(**data) class ReplicaConfig: def __init__(self, - backend_def: Callable, + deployment_def: Callable, init_args: Optional[Tuple[Any]] = None, init_kwargs: Optional[Dict[Any, Any]] = None, ray_actor_options=None): - # Validate that backend_def is an import path, function, or class. - if isinstance(backend_def, str): - self.func_or_class_name = backend_def - pass - elif inspect.isfunction(backend_def): - self.func_or_class_name = backend_def.__name__ + # Validate that deployment_def is an import path, function, or class. + if isinstance(deployment_def, str): + self.func_or_class_name = deployment_def + elif inspect.isfunction(deployment_def): + self.func_or_class_name = deployment_def.__name__ if init_args: raise ValueError( - "init_args not supported for function backend.") + "init_args not supported for function deployments.") if init_kwargs: raise ValueError( - "init_kwargs not supported for function backend.") - elif inspect.isclass(backend_def): - self.func_or_class_name = backend_def.__name__ + "init_kwargs not supported for function deployments.") + elif inspect.isclass(deployment_def): + self.func_or_class_name = deployment_def.__name__ else: raise TypeError( - "Backend must be an import path, function or class, it is {}.". - format(type(backend_def))) + "Deployment must be a function or class, it is {}.".format( + type(deployment_def))) - self.serialized_backend_def = cloudpickle.dumps(backend_def) + self.serialized_deployment_def = cloudpickle.dumps(deployment_def) self.init_args = init_args if init_args is not None else () self.init_kwargs = init_kwargs if init_kwargs is not None else {} if ray_actor_options is None: @@ -175,7 +174,7 @@ def __init__(self, def _validate(self): if "placement_group" in self.ray_actor_options: - raise ValueError("Providing placement_group for backend actors " + raise ValueError("Providing placement_group for deployment actors " "is not currently supported.") if not isinstance(self.ray_actor_options, dict): diff --git a/python/ray/serve/controller.py b/python/ray/serve/controller.py index e1dd906a398a..cbebddafa0c4 100644 --- a/python/ray/serve/controller.py +++ b/python/ray/serve/controller.py @@ -20,7 +20,7 @@ NodeId, RunningReplicaInfo, ) -from ray.serve.config import BackendConfig, HTTPOptions, ReplicaConfig +from ray.serve.config import DeploymentConfig, HTTPOptions, ReplicaConfig from ray.serve.constants import CONTROL_LOOP_PERIOD_S, SERVE_ROOT_URL_ENV_KEY from ray.serve.endpoint_state import EndpointState from ray.serve.http_state import HTTPState @@ -142,7 +142,7 @@ def autoscale(self) -> None: """Updates autoscaling deployments with calculated num_replicas.""" for deployment_name, (backend_info, route_prefix) in self.list_deployments().items(): - backend_config = backend_info.backend_config + deployment_config = backend_info.deployment_config autoscaling_policy = backend_info.autoscaling_policy if autoscaling_policy is None: @@ -166,16 +166,16 @@ def autoscale(self) -> None: if len(current_num_ongoing_requests) == 0: continue - new_backend_config = backend_config.copy() + new_deployment_config = deployment_config.copy() decision_num_replicas = ( autoscaling_policy.get_decision_num_replicas( current_num_ongoing_requests=current_num_ongoing_requests, - curr_target_num_replicas=backend_config.num_replicas)) - new_backend_config.num_replicas = decision_num_replicas + curr_target_num_replicas=deployment_config.num_replicas)) + new_deployment_config.num_replicas = decision_num_replicas new_backend_info = copy(backend_info) - new_backend_info.backend_config = new_backend_config + new_backend_info.deployment_config = new_deployment_config goal_id, updating = self.backend_state_manager.deploy_backend( deployment_name, new_backend_info) @@ -275,7 +275,7 @@ async def shutdown(self) -> List[GoalId]: def deploy(self, name: str, - backend_config_proto_bytes: bytes, + deployment_config_proto_bytes: bytes, replica_config: ReplicaConfig, version: Optional[str], prev_version: Optional[str], @@ -285,8 +285,8 @@ def deploy(self, if route_prefix is not None: assert route_prefix.startswith("/") - backend_config = BackendConfig.from_proto_bytes( - backend_config_proto_bytes) + deployment_config = DeploymentConfig.from_proto_bytes( + deployment_config_proto_bytes) if prev_version is not None: existing_backend_info = self.backend_state_manager.get_backend( @@ -301,10 +301,10 @@ def deploy(self, "does not match with the existing " f"version '{existing_backend_info.version}'.") - autoscaling_config = backend_config.autoscaling_config + autoscaling_config = deployment_config.autoscaling_config if autoscaling_config is not None: # TODO: is this the desired behaviour? Should this be a setting? - backend_config.num_replicas = autoscaling_config.min_replicas + deployment_config.num_replicas = autoscaling_config.min_replicas autoscaling_policy = BasicAutoscalingPolicy(autoscaling_config) else: @@ -312,10 +312,10 @@ def deploy(self, backend_info = BackendInfo( actor_def=ray.remote( - create_replica_wrapper(name, - replica_config.serialized_backend_def)), + create_replica_wrapper( + name, replica_config.serialized_deployment_def)), version=version, - backend_config=backend_config, + deployment_config=deployment_config, replica_config=replica_config, deployer_job_id=deployer_job_id, start_time_ms=int(time.time() * 1000), @@ -373,6 +373,6 @@ def list_deployments(self, include_deleted: Optional[bool] = False name: (self.backend_state_manager.get_backend( name, include_deleted=include_deleted), self.endpoint_state.get_endpoint_route(name)) - for name in self.backend_state_manager.get_backend_configs( + for name in self.backend_state_manager.get_deployment_configs( include_deleted=include_deleted) } diff --git a/python/ray/serve/replica.py b/python/ray/serve/replica.py index a85ed390b71c..45db316106a7 100644 --- a/python/ray/serve/replica.py +++ b/python/ray/serve/replica.py @@ -15,7 +15,7 @@ from ray.serve.autoscaling_metrics import start_metrics_pusher from ray.serve.common import BackendTag, ReplicaTag -from ray.serve.config import BackendConfig +from ray.serve.config import DeploymentConfig from ray.serve.http_util import ASGIHTTPSender from ray.serve.utils import parse_request_item, _get_logger from ray.serve.exceptions import RayServeException @@ -31,30 +31,30 @@ logger = _get_logger() -def create_replica_wrapper(name: str, serialized_backend_def: bytes): +def create_replica_wrapper(name: str, serialized_deployment_def: bytes): """Creates a replica class wrapping the provided function or class. This approach is picked over inheritance to avoid conflict between user provided class and the RayServeReplica class. """ - serialized_backend_def = serialized_backend_def + serialized_deployment_def = serialized_deployment_def # TODO(architkulkarni): Add type hints after upgrading cloudpickle class RayServeWrappedReplica(object): async def __init__(self, backend_tag, replica_tag, init_args, - init_kwargs, backend_config_proto_bytes: bytes, + init_kwargs, deployment_config_proto_bytes: bytes, version: DeploymentVersion, controller_name: str, detached: bool): - backend = cloudpickle.loads(serialized_backend_def) - backend_config = BackendConfig.from_proto_bytes( - backend_config_proto_bytes) + backend = cloudpickle.loads(serialized_deployment_def) + deployment_config = DeploymentConfig.from_proto_bytes( + deployment_config_proto_bytes) if inspect.isfunction(backend): is_function = True elif inspect.isclass(backend): is_function = False else: - assert False, ("backend_def must be function, class, or " + assert False, ("deployment_def must be function, class, or " "corresponding import path.") # Set the controller name so that serve.connect() in the user's @@ -85,10 +85,10 @@ async def __init__(self, backend_tag, replica_tag, init_args, detached) controller_handle = ray.get_actor( controller_name, namespace=controller_namespace) - self.backend = RayServeReplica(_callable, backend_tag, replica_tag, - backend_config, - backend_config.user_config, version, - is_function, controller_handle) + self.backend = RayServeReplica( + _callable, backend_tag, replica_tag, deployment_config, + deployment_config.user_config, version, is_function, + controller_handle) # asyncio.Event used to signal that the replica is shutting down. self.shutdown_event = asyncio.Event() @@ -109,14 +109,14 @@ async def handle_request( return await self.backend.handle_request(query) async def reconfigure(self, user_config: Optional[Any] = None - ) -> Tuple[BackendConfig, DeploymentVersion]: + ) -> Tuple[DeploymentConfig, DeploymentVersion]: if user_config is not None: await self.backend.reconfigure(user_config) return self.get_metadata() - def get_metadata(self) -> Tuple[BackendConfig, DeploymentVersion]: - return self.backend.backend_config, self.backend.version + def get_metadata(self) -> Tuple[DeploymentConfig, DeploymentVersion]: + return self.backend.deployment_config, self.backend.version async def prepare_for_shutdown(self): self.shutdown_event.set() @@ -146,10 +146,10 @@ class RayServeReplica: """Handles requests with the provided callable.""" def __init__(self, _callable: Callable, backend_tag: BackendTag, - replica_tag: ReplicaTag, backend_config: BackendConfig, + replica_tag: ReplicaTag, deployment_config: DeploymentConfig, user_config: Any, version: DeploymentVersion, is_function: bool, controller_handle: ActorHandle) -> None: - self.backend_config = backend_config + self.deployment_config = deployment_config self.backend_tag = backend_tag self.replica_tag = replica_tag self.callable = _callable @@ -211,10 +211,10 @@ def __init__(self, _callable: Callable, backend_tag: BackendTag, self.restart_counter.inc() self._shutdown_wait_loop_s = ( - backend_config.graceful_shutdown_wait_loop_s) + deployment_config.graceful_shutdown_wait_loop_s) - if backend_config.autoscaling_config: - config = backend_config.autoscaling_config + if deployment_config.autoscaling_config: + config = deployment_config.autoscaling_config start_metrics_pusher( interval_s=config.metrics_interval_s, collection_callback=self._collect_autoscaling_metrics, @@ -319,7 +319,8 @@ async def reconfigure(self, user_config: Any): self.version = DeploymentVersion( self.version.code_version, user_config=user_config) if self.is_function: - raise ValueError("backend_def must be a class to use user_config") + raise ValueError( + "deployment_def must be a class to use user_config") elif not hasattr(self.callable, BACKEND_RECONFIGURE_METHOD): raise RayServeException("user_config specified but backend " + self.backend_tag + " missing " + diff --git a/python/ray/serve/tests/test_backend_state.py b/python/ray/serve/tests/test_backend_state.py index 74a170a6f1c6..687382945249 100644 --- a/python/ray/serve/tests/test_backend_state.py +++ b/python/ray/serve/tests/test_backend_state.py @@ -8,7 +8,7 @@ from ray.actor import ActorHandle from ray.serve.common import ( - BackendConfig, + DeploymentConfig, BackendInfo, BackendTag, ReplicaConfig, @@ -130,7 +130,7 @@ def available_resources(self) -> Dict[str, float]: def graceful_stop(self) -> None: assert self.started self.stopped = True - return self.backend_info.backend_config.graceful_shutdown_timeout_s + return self.backend_info.deployment_config.graceful_shutdown_timeout_s def check_stopped(self) -> bool: return self.done_stopping @@ -154,7 +154,7 @@ def backend_info(version: Optional[str] = None, actor_def=None, version=version, start_time_ms=0, - backend_config=BackendConfig( + deployment_config=DeploymentConfig( num_replicas=num_replicas, user_config=user_config, **config_opts), replica_config=ReplicaConfig(lambda x: x)) @@ -163,7 +163,8 @@ def backend_info(version: Optional[str] = None, else: code_version = get_random_letters() - version = DeploymentVersion(code_version, info.backend_config.user_config) + version = DeploymentVersion(code_version, + info.deployment_config.user_config) return info, version diff --git a/python/ray/serve/tests/test_config.py b/python/ray/serve/tests/test_config.py index 505147055e34..f8e96487434d 100644 --- a/python/ray/serve/tests/test_config.py +++ b/python/ray/serve/tests/test_config.py @@ -1,29 +1,29 @@ import pytest from pydantic import ValidationError -from ray.serve.config import (BackendConfig, DeploymentMode, HTTPOptions, +from ray.serve.config import (DeploymentConfig, DeploymentMode, HTTPOptions, ReplicaConfig) from ray.serve.config import AutoscalingConfig -def test_backend_config_validation(): +def test_deployment_config_validation(): # Test unknown key. with pytest.raises(ValidationError): - BackendConfig(unknown_key=-1) + DeploymentConfig(unknown_key=-1) # Test num_replicas validation. - BackendConfig(num_replicas=1) + DeploymentConfig(num_replicas=1) with pytest.raises(ValidationError, match="type_error"): - BackendConfig(num_replicas="hello") + DeploymentConfig(num_replicas="hello") with pytest.raises(ValidationError, match="value_error"): - BackendConfig(num_replicas=-1) + DeploymentConfig(num_replicas=-1) # Test dynamic default for max_concurrent_queries. - assert BackendConfig().max_concurrent_queries == 100 + assert DeploymentConfig().max_concurrent_queries == 100 -def test_backend_config_update(): - b = BackendConfig(num_replicas=1, max_concurrent_queries=1) +def test_deployment_config_update(): + b = DeploymentConfig(num_replicas=1, max_concurrent_queries=1) # Test updating a key works. b.num_replicas = 2 @@ -108,18 +108,18 @@ def test_http_options(): def test_with_proto(): # Test roundtrip - config = BackendConfig(num_replicas=100, max_concurrent_queries=16) - assert config == BackendConfig.from_proto_bytes(config.to_proto_bytes()) + config = DeploymentConfig(num_replicas=100, max_concurrent_queries=16) + assert config == DeploymentConfig.from_proto_bytes(config.to_proto_bytes()) # Test user_config object - config = BackendConfig(user_config={"python": ("native", ["objects"])}) - assert config == BackendConfig.from_proto_bytes(config.to_proto_bytes()) + config = DeploymentConfig(user_config={"python": ("native", ["objects"])}) + assert config == DeploymentConfig.from_proto_bytes(config.to_proto_bytes()) def test_zero_default_proto(): # Test that options set to zero (protobuf default value) still retain their # original value after being serialized and deserialized. - config = BackendConfig( + config = DeploymentConfig( autoscaling_config={ "min_replicas": 1, "max_replicas": 2, @@ -127,7 +127,7 @@ def test_zero_default_proto(): "downscale_delay_s": 0 }) serialized_config = config.to_proto_bytes() - deserialized_config = BackendConfig.from_proto_bytes(serialized_config) + deserialized_config = DeploymentConfig.from_proto_bytes(serialized_config) new_delay_s = deserialized_config.autoscaling_config.downscale_delay_s assert new_delay_s == 0 diff --git a/src/ray/protobuf/serve.proto b/src/ray/protobuf/serve.proto index 2cadc2128dfb..20f451f1129d 100644 --- a/src/ray/protobuf/serve.proto +++ b/src/ray/protobuf/serve.proto @@ -51,40 +51,40 @@ message AutoscalingConfig { double upscale_delay_s = 8; } -// Configuration options for a backend, to be set by the user. -message BackendConfig { - // The number of processes to start up that will handle requests to this backend. +// Configuration options for a deployment, to be set by the user. +message DeploymentConfig { + // The number of processes to start up that will handle requests to this deployment. // Defaults to 1. int32 num_replicas = 1; - // The maximum number of queries that will be sent to a replica of this backend without - // receiving a response. Defaults to 100. + // The maximum number of queries that will be sent to a replica of this deployment + // without receiving a response. Defaults to 100. int32 max_concurrent_queries = 2; - // Arguments to pass to the reconfigure method of the backend. The reconfigure method is - // called if user_config is not None. + // Arguments to pass to the reconfigure method of the deployment. The reconfigure method + // is called if user_config is not None. bytes user_config = 3; - // Duration that backend workers will wait until there is no more work to be done before - // shutting down. Defaults to 2s. + // Duration that deployment replicas will wait until there is no more work to be done + // before shutting down. Defaults to 2s. double graceful_shutdown_wait_loop_s = 4; // Controller waits for this duration to forcefully kill the replica for shutdown. // Defaults to 20s. double graceful_shutdown_timeout_s = 5; - // Is the construction of backend is cross language? + // Is the construction of deployment is cross language? bool is_cross_language = 6; - // The backend's programming language. - BackendLanguage backend_language = 7; + // The deployment's programming language. + DeploymentLanguage deployment_language = 7; - // The backend's autoscaling configuration. + // The deployment's autoscaling configuration. AutoscalingConfig autoscaling_config = 8; } // Backend language. -enum BackendLanguage { +enum DeploymentLanguage { PYTHON = 0; JAVA = 1; } From a03c4363b5323aab16e1631de83586393d0c7983 Mon Sep 17 00:00:00 2001 From: Hao Zhang Date: Mon, 1 Nov 2021 10:25:43 -0700 Subject: [PATCH 03/15] [Collective] Allow send/recv partial tensors in Send/Recv primitives (#19921) --- python/ray/util/collective/collective.py | 16 ++++++++++++++-- .../collective_group/nccl_collective_group.py | 6 ++++-- python/ray/util/collective/types.py | 2 ++ 3 files changed, 20 insertions(+), 4 deletions(-) diff --git a/python/ray/util/collective/collective.py b/python/ray/util/collective/collective.py index 4cebe9eefcfd..16d1fea7b249 100644 --- a/python/ray/util/collective/collective.py +++ b/python/ray/util/collective/collective.py @@ -544,7 +544,8 @@ def send(tensor, dst_rank: int, group_name: str = "default"): def send_multigpu(tensor, dst_rank: int, dst_gpu_index: int, - group_name: str = "default"): + group_name: str = "default", + n_elements: int = 0): """Send a tensor to a remote GPU synchronously. The function asssume each process owns >1 GPUs, and the sender @@ -555,6 +556,8 @@ def send_multigpu(tensor, dst_rank (int): the rank of the destination process. dst_gpu_index (int): the destination gpu index. group_name (str): the name of the collective group. + n_elements (int): if specified, send the next n elements + from the starting address of tensor. Returns: None @@ -567,9 +570,13 @@ def send_multigpu(tensor, if dst_rank == g.rank: raise RuntimeError("The dst_rank '{}' is self. Considering " "doing GPU to GPU memcpy instead?".format(dst_rank)) + if n_elements < 0: + raise RuntimeError( + "The n_elements '{}' should >= 0.".format(n_elements)) opts = types.SendOptions() opts.dst_rank = dst_rank opts.dst_gpu_index = dst_gpu_index + opts.n_elements = n_elements g.send([tensor], opts) @@ -598,7 +605,8 @@ def recv(tensor, src_rank: int, group_name: str = "default"): def recv_multigpu(tensor, src_rank: int, src_gpu_index: int, - group_name: str = "default"): + group_name: str = "default", + n_elements: int = 0): """Receive a tensor from a remote GPU synchronously. The function asssume each process owns >1 GPUs, and the sender @@ -621,9 +629,13 @@ def recv_multigpu(tensor, if src_rank == g.rank: raise RuntimeError("The dst_rank '{}' is self. Considering " "doing GPU to GPU memcpy instead?".format(src_rank)) + if n_elements < 0: + raise RuntimeError( + "The n_elements '{}' should be >= 0.".format(n_elements)) opts = types.RecvOptions() opts.src_rank = src_rank opts.src_gpu_index = src_gpu_index + opts.n_elements = n_elements g.recv([tensor], opts) diff --git a/python/ray/util/collective/collective_group/nccl_collective_group.py b/python/ray/util/collective/collective_group/nccl_collective_group.py index a73dc9526e7a..6825ed0813da 100644 --- a/python/ray/util/collective/collective_group/nccl_collective_group.py +++ b/python/ray/util/collective/collective_group/nccl_collective_group.py @@ -348,7 +348,8 @@ def send(self, tensors, send_options=SendOptions()): def p2p_fn(tensor, comm, stream, peer): comm.send( - nccl_util.get_tensor_ptr(tensor), + nccl_util.get_tensor_ptr(tensor), send_options.n_elements + if send_options.n_elements > 0 else nccl_util.get_tensor_n_elements(tensor), nccl_util.get_nccl_tensor_dtype(tensor), peer, stream.ptr) @@ -368,7 +369,8 @@ def recv(self, tensors, recv_options=RecvOptions()): def p2p_fn(tensor, comm, stream, peer): comm.recv( - nccl_util.get_tensor_ptr(tensor), + nccl_util.get_tensor_ptr(tensor), recv_options.n_elements + if recv_options.n_elements > 0 else nccl_util.get_tensor_n_elements(tensor), nccl_util.get_nccl_tensor_dtype(tensor), peer, stream.ptr) diff --git a/python/ray/util/collective/types.py b/python/ray/util/collective/types.py index 6aff71b3a53d..b949177657c7 100644 --- a/python/ray/util/collective/types.py +++ b/python/ray/util/collective/types.py @@ -101,6 +101,7 @@ class ReduceScatterOptions: class SendOptions: dst_rank = 0 dst_gpu_index = 0 + n_elements = 0 timeout_ms = unset_timeout_ms @@ -108,4 +109,5 @@ class SendOptions: class RecvOptions: src_rank = 0 src_gpu_index = 0 + n_elements = 0 unset_timeout_ms = unset_timeout_ms From 1803ca13b62649b032b9db8bc59adbcd3bb93078 Mon Sep 17 00:00:00 2001 From: xwjiang2010 <87673679+xwjiang2010@users.noreply.github.com> Date: Mon, 1 Nov 2021 10:26:04 -0700 Subject: [PATCH 04/15] Adding release logs for 1.8.0. (#19867) --- .../1.8.0/benchmarks/many_actors.json | 10 ++ .../1.8.0/benchmarks/many_nodes.json | 10 ++ .../1.8.0/benchmarks/many_pgs.json | 10 ++ .../1.8.0/benchmarks/many_tasks.json | 10 ++ .../release_logs/1.8.0/microbenchmark.json | 134 ++++++++++++++++++ .../1.8.0/scalability/object_store.json | 10 ++ .../1.8.0/scalability/single_node.json | 17 +++ .../stress_tests/stress_test_dead_actors.json | 11 ++ .../stress_tests/stress_test_many_tasks.json | 19 +++ .../stress_test_placement_group.json | 9 ++ 10 files changed, 240 insertions(+) create mode 100644 release/release_logs/1.8.0/benchmarks/many_actors.json create mode 100644 release/release_logs/1.8.0/benchmarks/many_nodes.json create mode 100644 release/release_logs/1.8.0/benchmarks/many_pgs.json create mode 100644 release/release_logs/1.8.0/benchmarks/many_tasks.json create mode 100644 release/release_logs/1.8.0/microbenchmark.json create mode 100644 release/release_logs/1.8.0/scalability/object_store.json create mode 100644 release/release_logs/1.8.0/scalability/single_node.json create mode 100644 release/release_logs/1.8.0/stress_tests/stress_test_dead_actors.json create mode 100644 release/release_logs/1.8.0/stress_tests/stress_test_many_tasks.json create mode 100644 release/release_logs/1.8.0/stress_tests/stress_test_placement_group.json diff --git a/release/release_logs/1.8.0/benchmarks/many_actors.json b/release/release_logs/1.8.0/benchmarks/many_actors.json new file mode 100644 index 000000000000..1b9e2b210fae --- /dev/null +++ b/release/release_logs/1.8.0/benchmarks/many_actors.json @@ -0,0 +1,10 @@ +{ + "actors_per_second":502.27667887403527, + "num_actors":10000, + "time":19.909345626831055, + "success":"1", + "_runtime":34.931002140045166, + "_session_url":"https://beta.anyscale.com/o/anyscale-internal/projects/prj_2xR6uT6t7jJuu1aCwWMsle/clusters/ses_NtMW8qrs2wfGbb1DSfhkXpa7", + "_commit_url":"https://s3-us-west-2.amazonaws.com/ray-wheels/releases/1.8.0/d28be7e0c555808b38a5ce2da8d6c48d5162ce12/ray-1.8.0-cp37-cp37m-manylinux2014_x86_64.whl", + "_stable":true +} diff --git a/release/release_logs/1.8.0/benchmarks/many_nodes.json b/release/release_logs/1.8.0/benchmarks/many_nodes.json new file mode 100644 index 000000000000..4272a3c2aadb --- /dev/null +++ b/release/release_logs/1.8.0/benchmarks/many_nodes.json @@ -0,0 +1,10 @@ +{ + "tasks_per_second":3.1463954810569446, + "num_tasks":1000, + "time":617.8240008354187, + "success":"1", + "_runtime":627.3966097831726, + "_session_url":"https://beta.anyscale.com/o/anyscale-internal/projects/prj_2xR6uT6t7jJuu1aCwWMsle/clusters/ses_M5RJTBWv4HPcVW4LJBpQ3QbU", + "_commit_url":"https://s3-us-west-2.amazonaws.com/ray-wheels/releases/1.8.0/d28be7e0c555808b38a5ce2da8d6c48d5162ce12/ray-1.8.0-cp37-cp37m-manylinux2014_x86_64.whl", + "_stable":true +} diff --git a/release/release_logs/1.8.0/benchmarks/many_pgs.json b/release/release_logs/1.8.0/benchmarks/many_pgs.json new file mode 100644 index 000000000000..4c834fb150d7 --- /dev/null +++ b/release/release_logs/1.8.0/benchmarks/many_pgs.json @@ -0,0 +1,10 @@ +{ + "pgs_per_second":18.38968677640112, + "num_pgs":1000, + "time":54.378305196762085, + "success":"1", + "_runtime":70.17751049995422, + "_session_url":"https://beta.anyscale.com/o/anyscale-internal/projects/prj_2xR6uT6t7jJuu1aCwWMsle/clusters/ses_NTMUAfRBeFAnGucHNvxk9SCe", + "_commit_url":"https://s3-us-west-2.amazonaws.com/ray-wheels/releases/1.8.0/d28be7e0c555808b38a5ce2da8d6c48d5162ce12/ray-1.8.0-cp37-cp37m-manylinux2014_x86_64.whl", + "_stable":true +} diff --git a/release/release_logs/1.8.0/benchmarks/many_tasks.json b/release/release_logs/1.8.0/benchmarks/many_tasks.json new file mode 100644 index 000000000000..70c9e1af8a41 --- /dev/null +++ b/release/release_logs/1.8.0/benchmarks/many_tasks.json @@ -0,0 +1,10 @@ +{ + "tasks_per_second":27.380515768078205, + "num_tasks":10000, + "time":665.2232151031494, + "success":"1", + "_runtime":676.2212898731232, + "_session_url":"https://beta.anyscale.com/o/anyscale-internal/projects/prj_2xR6uT6t7jJuu1aCwWMsle/clusters/ses_DRNyaBrjPz92eS7XGH6s86J2", + "_commit_url":"https://s3-us-west-2.amazonaws.com/ray-wheels/releases/1.8.0/d28be7e0c555808b38a5ce2da8d6c48d5162ce12/ray-1.8.0-cp37-cp37m-manylinux2014_x86_64.whl", + "_stable":true +} diff --git a/release/release_logs/1.8.0/microbenchmark.json b/release/release_logs/1.8.0/microbenchmark.json new file mode 100644 index 000000000000..1d8c2cc53d17 --- /dev/null +++ b/release/release_logs/1.8.0/microbenchmark.json @@ -0,0 +1,134 @@ +{ + "single_client_get_calls":[ + 30940.134096399626, + 317.0832533971391 + ], + "single_client_put_calls":[ + 48943.191804559654, + 219.49873322500986 + ], + "multi_client_put_calls":[ + 196309.52527610504, + 3737.9627941064145 + ], + "single_client_get_calls_Plasma_Store":[ + 6846.037019609748, + 22.8467171817285 + ], + "single_client_put_calls_Plasma_Store":[ + 6460.924069006876, + 80.8868946635643 + ], + "multi_client_put_calls_Plasma_Store":[ + 10041.514280594329, + 156.65023333195015 + ], + "single_client_put_gigabytes":[ + 19.409071238092803, + 5.346465780702707 + ], + "single_client_tasks_and_get_batch":[ + 14.009166496750346, + 0.21491106195168685 + ], + "multi_client_put_gigabytes":[ + 35.23946199953088, + 0.7596648393204519 + ], + "single_client_get_object_containing_10k_refs":[ + 13.444237954966656, + 0.23494256884483813 + ], + "single_client_tasks_sync":[ + 1635.8932418252052, + 27.31879448903416 + ], + "single_client_tasks_async":[ + 13818.97086807901, + 296.5632852961327 + ], + "multi_client_tasks_async":[ + 39833.383674172335, + 2733.4624633691615 + ], + "1_1_actor_calls_sync":[ + 2642.40420304774, + 41.91740557372103 + ], + "1_1_actor_calls_async":[ + 6926.955725014122, + 41.74140409894753 + ], + "1_1_actor_calls_concurrent":[ + 5676.325754099567, + 151.98507185863667 + ], + "1_n_actor_calls_async":[ + 14243.648501790683, + 308.2488257960866 + ], + "n_n_actor_calls_async":[ + 44238.1477273468, + 3957.0334984514498 + ], + "n_n_actor_calls_with_arg_async":[ + 3338.0491560395135, + 20.518610734699735 + ], + "1_1_async_actor_calls_sync":[ + 1960.6729229123905, + 19.399824989325484 + ], + "1_1_async_actor_calls_async":[ + 4015.4271724213463, + 325.935225727618 + ], + "1_1_async_actor_calls_with_args_async":[ + 2844.8051074858054, + 170.12415992293433 + ], + "1_n_async_actor_calls_async":[ + 15039.391685457813, + 680.5919632322198 + ], + "n_n_async_actor_calls_async":[ + 39514.27962589182, + 1501.8795564889163 + ], + "client__get_calls":[ + 1975.254953927033, + 8.735788531735636 + ], + "client__put_calls":[ + 1098.0277773219968, + 7.155561544388648 + ], + "client__put_gigabytes":[ + 0.13727111000071365, + 0.0026377327384535312 + ], + "client__tasks_and_put_batch":[ + 69321.64862718538, + 1767.6000938242169 + ], + "client__1_1_actor_calls_sync":[ + 564.4912609733829, + 30.465268060749132 + ], + "client__1_1_actor_calls_async":[ + 943.8060173510352, + 17.7532101628873 + ], + "client__1_1_actor_calls_concurrent":[ + 949.6015148062681, + 12.69568974417014 + ], + "client__tasks_and_get_batch":[ + 0.8703009855123071, + 0.013280271782555824 + ], + "_runtime":533.6895747184753, + "_session_url":"https://beta.anyscale.com/o/anyscale-internal/projects/prj_2xR6uT6t7jJuu1aCwWMsle/clusters/ses_DFQpftV7rrj7THGByj5N9ELj", + "_commit_url":"https://s3-us-west-2.amazonaws.com/ray-wheels/releases/1.8.0/d28be7e0c555808b38a5ce2da8d6c48d5162ce12/ray-1.8.0-cp37-cp37m-manylinux2014_x86_64.whl", + "_stable":true +} diff --git a/release/release_logs/1.8.0/scalability/object_store.json b/release/release_logs/1.8.0/scalability/object_store.json new file mode 100644 index 000000000000..4f4847311731 --- /dev/null +++ b/release/release_logs/1.8.0/scalability/object_store.json @@ -0,0 +1,10 @@ +{ + "broadcast_time":1478.5116119949998, + "object_size":1073741824, + "num_nodes":50, + "success":"1", + "_runtime":1500.23917222023, + "_session_url":"https://beta.anyscale.com/o/anyscale-internal/projects/prj_2xR6uT6t7jJuu1aCwWMsle/clusters/ses_4HtGPbfxkKyjpLYxs4Xc7Quf", + "_commit_url":"https://s3-us-west-2.amazonaws.com/ray-wheels/releases/1.8.0/d28be7e0c555808b38a5ce2da8d6c48d5162ce12/ray-1.8.0-cp37-cp37m-manylinux2014_x86_64.whl", + "_stable":true +} diff --git a/release/release_logs/1.8.0/scalability/single_node.json b/release/release_logs/1.8.0/scalability/single_node.json new file mode 100644 index 000000000000..5f3368320acf --- /dev/null +++ b/release/release_logs/1.8.0/scalability/single_node.json @@ -0,0 +1,17 @@ +{ + "args_time":17.633971685999995, + "num_args":10000, + "returns_time":5.793101844999967, + "num_returns":3000, + "get_time":28.41037361399998, + "num_get_args":10000, + "queued_time":155.58971768599997, + "num_queued":1000000, + "large_object_time":289.177471706, + "large_object_size":107374182400, + "success":"1", + "_runtime":547.9792876243591, + "_session_url":"https://beta.anyscale.com/o/anyscale-internal/projects/prj_2xR6uT6t7jJuu1aCwWMsle/clusters/ses_BiK1hJqWQaKY1kPtzWZtrk7z", + "_commit_url":"https://s3-us-west-2.amazonaws.com/ray-wheels/releases/1.8.0/d28be7e0c555808b38a5ce2da8d6c48d5162ce12/ray-1.8.0-cp37-cp37m-manylinux2014_x86_64.whl", + "_stable":true +} diff --git a/release/release_logs/1.8.0/stress_tests/stress_test_dead_actors.json b/release/release_logs/1.8.0/stress_tests/stress_test_dead_actors.json new file mode 100644 index 000000000000..7b0d4fe4ecc8 --- /dev/null +++ b/release/release_logs/1.8.0/stress_tests/stress_test_dead_actors.json @@ -0,0 +1,11 @@ +{ + "success":1, + "total_time":131.1956226825714, + "avg_iteration_time":1.3119536685943602, + "max_iteration_time":3.662248373031616, + "min_iteration_time":0.08778810501098633, + "_runtime":4927.166862249374, + "_session_url":"https://beta.anyscale.com/o/anyscale-internal/projects/prj_2xR6uT6t7jJuu1aCwWMsle/clusters/ses_bEKXDxxZwwg3p4cYEiWA82gR", + "_commit_url":"https://s3-us-west-2.amazonaws.com/ray-wheels/releases/1.8.0/d28be7e0c555808b38a5ce2da8d6c48d5162ce12/ray-1.8.0-cp37-cp37m-manylinux2014_x86_64.whl", + "_stable":true +} diff --git a/release/release_logs/1.8.0/stress_tests/stress_test_many_tasks.json b/release/release_logs/1.8.0/stress_tests/stress_test_many_tasks.json new file mode 100644 index 000000000000..3a6c9c74e9b1 --- /dev/null +++ b/release/release_logs/1.8.0/stress_tests/stress_test_many_tasks.json @@ -0,0 +1,19 @@ +{ + "success":1, + "stage_0_time":5.756206750869751, + "stage_1_time":190.08577489852905, + "stage_1_avg_iteration_time":19.008567762374877, + "stage_1_max_iteration_time":19.4687020778656, + "stage_1_min_iteration_time":18.608890295028687, + "stage_2_time":246.9729506969452, + "stage_2_avg_iteration_time":49.39428896903992, + "stage_2_max_iteration_time":50.23058724403381, + "stage_2_min_iteration_time":47.09399747848511, + "stage_3_creation_time":0.05593752861022949, + "stage_3_time":1843.905479669571, + "stage_4_spread":3.2320969134286446, + "_runtime":4446.973560810089, + "_session_url":"https://beta.anyscale.com/o/anyscale-internal/projects/prj_2xR6uT6t7jJuu1aCwWMsle/clusters/ses_wv1Ch4n2WCKNDLzUJwySyYJ2", + "_commit_url":"https://s3-us-west-2.amazonaws.com/ray-wheels/releases/1.8.0/d28be7e0c555808b38a5ce2da8d6c48d5162ce12/ray-1.8.0-cp37-cp37m-manylinux2014_x86_64.whl", + "_stable":true +} diff --git a/release/release_logs/1.8.0/stress_tests/stress_test_placement_group.json b/release/release_logs/1.8.0/stress_tests/stress_test_placement_group.json new file mode 100644 index 000000000000..6b22a4231ff3 --- /dev/null +++ b/release/release_logs/1.8.0/stress_tests/stress_test_placement_group.json @@ -0,0 +1,9 @@ +{ + "success":1, + "avg_pg_create_time_ms":0.9178227477476738, + "avg_pg_remove_time_ms":3.5015487627617348, + "_runtime":381.0450084209442, + "_session_url":"https://beta.anyscale.com/o/anyscale-internal/projects/prj_2xR6uT6t7jJuu1aCwWMsle/clusters/ses_hMCMNudYEmmqv987qYbadjrQ", + "_commit_url":"https://s3-us-west-2.amazonaws.com/ray-wheels/releases/1.8.0/d28be7e0c555808b38a5ce2da8d6c48d5162ce12/ray-1.8.0-cp37-cp37m-manylinux2014_x86_64.whl", + "_stable":true +} From 80fb3f10ae59361845c6a53ee82847c3ec88d52a Mon Sep 17 00:00:00 2001 From: Alex Wu Date: Mon, 1 Nov 2021 11:44:59 -0700 Subject: [PATCH 05/15] [ci] Script for building M1 wheels (#19925) This PR includes a script for building wheels for Macs with M1 processors. It roughly follows the pattern of the other scripts with a few differences. Manually installs nvm Uses miniforge conda to install python/pip instead of python foundation .pkgs Doesn't pin numpy (we probably shouldn't be pinning it in the other scripts either...) Commit detection falls back to git instead of erroring All of these changes were made so that the script works on a laptop, which comes with a subset of the dependencies that the x86 buildkite image comes with. --- python/build-wheel-macos-arm64.sh | 92 +++++++++++++++++++++++++++++++ 1 file changed, 92 insertions(+) create mode 100644 python/build-wheel-macos-arm64.sh diff --git a/python/build-wheel-macos-arm64.sh b/python/build-wheel-macos-arm64.sh new file mode 100644 index 000000000000..097834570469 --- /dev/null +++ b/python/build-wheel-macos-arm64.sh @@ -0,0 +1,92 @@ +#!/bin/bash + +# Cause the script to exit if a single command fails. +set -e + +# Show explicitly which commands are currently running. +set -x + +DOWNLOAD_DIR=python_downloads + +NODE_VERSION="14" +PY_VERSIONS=("3.8.2" + "3.9.1") +PY_MMS=("3.8" + "3.9") + + +if [[ -n "${SKIP_DEP_RES}" ]]; then + ./ci/travis/install-bazel.sh + + curl -o- https://raw.githubusercontent.com/nvm-sh/nvm/v0.39.0/install.sh | bash + curl -o- https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-MacOSX-arm64.sh | bash + source ~/.bash_profile + conda init bash + source ~/.bash_profile + + # Use the latest version of Node.js in order to build the dashboard. + source "$HOME"/.nvm/nvm.sh + nvm install $NODE_VERSION + nvm use $NODE_VERSION +fi + +# Build the dashboard so its static assets can be included in the wheel. +pushd python/ray/dashboard/client + npm ci + npm run build +popd + +mkdir -p .whl + +for ((i=0; i<${#PY_VERSIONS[@]}; ++i)); do + PY_MM=${PY_MMS[i]} + CONDA_ENV_NAME="p$PY_MM" + + # The -f flag is passed twice to also run git clean in the arrow subdirectory. + # The -d flag removes directories. The -x flag ignores the .gitignore file, + # and the -e flag ensures that we don't remove the .whl directory. + git clean -f -f -x -d -e .whl -e $DOWNLOAD_DIR -e python/ray/dashboard/client -e dashboard/client + + + # Install python using conda. This should be easier to produce consistent results in buildkite and locally. + source ~/.bash_profile + conda create -y -n "$CONDA_ENV_NAME" + conda activate "$CONDA_ENV_NAME" + conda remove -y python || true + conda install -y python="$PY_MM" + + # NOTE: We expect conda to set the PATH properly. + PIP_CMD=pip + PYTHON_EXE=python + + $PIP_CMD install --upgrade pip + + if [ -z "${TRAVIS_COMMIT}" ]; then + TRAVIS_COMMIT=${BUILDKITE_COMMIT} + fi + + pushd python + # Setuptools on CentOS is too old to install arrow 0.9.0, therefore we upgrade. + $PIP_CMD install --upgrade setuptools + $PIP_CMD install -q cython==0.29.15 + # Install wheel to avoid the error "invalid command 'bdist_wheel'". + $PIP_CMD install -q wheel + # Set the commit SHA in __init__.py. + if [ -n "$TRAVIS_COMMIT" ]; then + echo "TRAVIS_COMMIT variable detected. ray.__commit__ will be set to $TRAVIS_COMMIT" + else + echo "TRAVIS_COMMIT variable is not set, getting the current commit from git." + TRAVIS_COMMIT=$(git rev-parse HEAD) + fi + + sed -i .bak "s/{{RAY_COMMIT_SHA}}/$TRAVIS_COMMIT/g" ray/__init__.py && rm ray/__init__.py.bak + + # Add the correct Python to the path and build the wheel. This is only + # needed so that the installation finds the cython executable. + # build ray wheel + $PYTHON_EXE setup.py bdist_wheel + # build ray-cpp wheel + RAY_INSTALL_CPP=1 $PYTHON_EXE setup.py bdist_wheel + mv dist/*.whl ../.whl/ + popd +done From bab9c0f67018f2a101a0aeb775d60600775d9f63 Mon Sep 17 00:00:00 2001 From: Sven Mika Date: Mon, 1 Nov 2021 21:45:11 +0100 Subject: [PATCH 06/15] [RLlib; Docs overhaul] Redo: Docstring cleanup: Trainer, trainer_template, Callbacks."" (#19830) --- python/ray/tune/experiment.py | 2 +- rllib/agents/callbacks.py | 54 ++- rllib/agents/trainer.py | 629 +++++++++++++++++-------------- rllib/agents/trainer_template.py | 74 ++-- 4 files changed, 400 insertions(+), 359 deletions(-) diff --git a/python/ray/tune/experiment.py b/python/ray/tune/experiment.py index 01618be75caa..964f45fc3e8c 100644 --- a/python/ray/tune/experiment.py +++ b/python/ray/tune/experiment.py @@ -85,7 +85,7 @@ class Experiment: max_failures=2) """ - # keys that will be present in `public_spec` dict + # Keys that will be present in `public_spec` dict. PUBLIC_KEYS = {"stop", "num_samples"} def __init__(self, diff --git a/rllib/agents/callbacks.py b/rllib/agents/callbacks.py index 561a98075fc7..3d48fc712c99 100644 --- a/rllib/agents/callbacks.py +++ b/rllib/agents/callbacks.py @@ -14,6 +14,7 @@ import psutil if TYPE_CHECKING: + from ray.rllib.agents.trainer import Trainer from ray.rllib.evaluation import RolloutWorker @@ -51,8 +52,6 @@ def on_episode_start(self, *, worker: "RolloutWorker", base_env: BaseEnv, state. You can use the `episode.user_data` dict to store temporary data, and `episode.custom_metrics` to store custom metrics for the episode. - env_index: Obsoleted: The ID of the environment, which the - episode belongs to. kwargs: Forward compatibility placeholder. """ @@ -73,19 +72,17 @@ def on_episode_step(self, """Runs on each episode step. Args: - worker (RolloutWorker): Reference to the current rollout worker. - base_env (BaseEnv): BaseEnv running the episode. The underlying + worker: Reference to the current rollout worker. + base_env: BaseEnv running the episode. The underlying sub environment objects can be retrieved by calling `base_env.get_sub_environments()`. - policies (Optional[Dict[PolicyID, Policy]]): Mapping of policy id + policies: Mapping of policy id to policy objects. In single agent mode there will only be a single "default_policy". - episode (Episode): Episode object which contains episode + episode: Episode object which contains episode state. You can use the `episode.user_data` dict to store temporary data, and `episode.custom_metrics` to store custom metrics for the episode. - env_index (EnvID): Obsoleted: The ID of the environment, which the - episode belongs to. kwargs: Forward compatibility placeholder. """ @@ -101,19 +98,17 @@ def on_episode_end(self, *, worker: "RolloutWorker", base_env: BaseEnv, """Runs when an episode is done. Args: - worker (RolloutWorker): Reference to the current rollout worker. - base_env (BaseEnv): BaseEnv running the episode. The underlying + worker: Reference to the current rollout worker. + base_env: BaseEnv running the episode. The underlying sub environment objects can be retrieved by calling `base_env.get_sub_environments()`. - policies (Dict[PolicyID, Policy]): Mapping of policy id to policy + policies: Mapping of policy id to policy objects. In single agent mode there will only be a single "default_policy". - episode (Episode): Episode object which contains episode + episode: Episode object which contains episode state. You can use the `episode.user_data` dict to store temporary data, and `episode.custom_metrics` to store custom metrics for the episode. - env_index (EnvID): Obsoleted: The ID of the environment, which the - episode belongs to. kwargs: Forward compatibility placeholder. """ @@ -136,16 +131,16 @@ def on_postprocess_trajectory( settings. Args: - worker (RolloutWorker): Reference to the current rollout worker. - episode (Episode): Episode object. - agent_id (str): Id of the current agent. - policy_id (str): Id of the current policy for the agent. - policies (dict): Mapping of policy id to policy objects. In single + worker: Reference to the current rollout worker. + episode: Episode object. + agent_id: Id of the current agent. + policy_id: Id of the current policy for the agent. + policies: Mapping of policy id to policy objects. In single agent mode there will only be a single "default_policy". - postprocessed_batch (SampleBatch): The postprocessed sample batch + postprocessed_batch: The postprocessed sample batch for this agent. You can mutate this object to apply your own trajectory postprocessing. - original_batches (dict): Mapping of agents to their unpostprocessed + original_batches: Mapping of agents to their unpostprocessed trajectory data. You should not mutate this object. kwargs: Forward compatibility placeholder. """ @@ -164,8 +159,8 @@ def on_sample_end(self, *, worker: "RolloutWorker", samples: SampleBatch, """Called at the end of RolloutWorker.sample(). Args: - worker (RolloutWorker): Reference to the current rollout worker. - samples (SampleBatch): Batch to be returned. You can mutate this + worker: Reference to the current rollout worker. + samples: Batch to be returned. You can mutate this object to modify the samples generated. kwargs: Forward compatibility placeholder. """ @@ -184,21 +179,22 @@ def on_learn_on_batch(self, *, policy: Policy, train_batch: SampleBatch, `pad_batch_to_sequences_of_same_size`. Args: - policy (Policy): Reference to the current Policy object. - train_batch (SampleBatch): SampleBatch to be trained on. You can + policy: Reference to the current Policy object. + train_batch: SampleBatch to be trained on. You can mutate this object to modify the samples generated. - result (dict): A results dict to add custom metrics to. + result: A results dict to add custom metrics to. kwargs: Forward compatibility placeholder. """ pass - def on_train_result(self, *, trainer, result: dict, **kwargs) -> None: + def on_train_result(self, *, trainer: "Trainer", result: dict, + **kwargs) -> None: """Called at the end of Trainable.train(). Args: - trainer (Trainer): Current trainer instance. - result (dict): Dict of results returned from trainer.train() call. + trainer: Current trainer instance. + result: Dict of results returned from trainer.train() call. You can mutate this object to add additional metrics. kwargs: Forward compatibility placeholder. """ diff --git a/rllib/agents/trainer.py b/rllib/agents/trainer.py index 375a3d1ac888..e0678a62fcb5 100644 --- a/rllib/agents/trainer.py +++ b/rllib/agents/trainer.py @@ -529,19 +529,34 @@ def with_common_config( @PublicAPI class Trainer(Trainable): - """A trainer coordinates the optimization of one or more RL policies. - - All RLlib trainers extend this base class, e.g., the A3CTrainer implements - the A3C algorithm for single and multi-agent training. - - Trainer objects retain internal model state between calls to train(), so - you should create a new trainer instance for each training session. - - Attributes: - env_creator (func): Function that creates a new training env. - config (obj): Algorithm-specific configuration data. - logdir (str): Directory in which training outputs should be placed. + """An RLlib algorithm responsible for optimizing one or more Policies. + + Trainers contain a WorkerSet under `self.workers`. A WorkerSet is + normally composed of a single local worker + (self.workers.local_worker()), used to compute and apply learning updates, + and optionally one or more remote workers (self.workers.remote_workers()), + used to generate environment samples in parallel. + + Each worker (remotes or local) contains a PolicyMap, which itself + may contain either one policy for single-agent training or one or more + policies for multi-agent training. Policies are synchronized + automatically from time to time using ray.remote calls. The exact + synchronization logic depends on the specific algorithm (Trainer) used, + but this usually happens from local worker to all remote workers and + after each training update. + + You can write your own Trainer sub-classes by using the + rllib.agents.trainer_template.py::build_trainer() utility function. + This allows you to provide a custom `execution_plan`. You can find the + different built-in algorithms' execution plans in their respective main + py files, e.g. rllib.agents.dqn.dqn.py or rllib.agents.impala.impala.py. + + The most important API methods a Trainer exposes are `train()`, + `evaluate()`, `save()` and `restore()`. Trainer objects retain internal + model state between calls to train(), so you should create a new + Trainer instance for each training session. """ + # Whether to allow unknown top-level config keys. _allow_unknown_configs = False @@ -562,15 +577,18 @@ class Trainer(Trainable): @PublicAPI def __init__(self, config: TrainerConfigDict = None, - env: str = None, + env: Union[str, EnvType, None] = None, logger_creator: Callable[[], Logger] = None): - """Initialize an RLLib trainer. + """Initializes a Trainer instance. Args: - config (dict): Algorithm-specific configuration data. - env (str): Name of the environment to use. Note that this can also - be specified as the `env` key in config. - logger_creator (func): Function that creates a ray.tune.Logger + config: Algorithm-specific configuration dict. + env: Name of the environment to use (e.g. a gym-registered str), + a full class path (e.g. + "ray.rllib.examples.env.random_env.RandomEnv"), or an Env + class directly. Note that this arg can also be specified via + the "env" key in `config`. + logger_creator: Callable that creates a ray.tune.Logger object. If unspecified, a default logger is created. """ @@ -623,151 +641,6 @@ def default_logger_creator(config): super().__init__(config, logger_creator) - @classmethod - @override(Trainable) - def default_resource_request( - cls, config: PartialTrainerConfigDict) -> \ - Union[Resources, PlacementGroupFactory]: - cf = dict(cls._default_config, **config) - - eval_config = cf["evaluation_config"] - - # TODO(ekl): add custom resources here once tune supports them - # Return PlacementGroupFactory containing all needed resources - # (already properly defined as device bundles). - return PlacementGroupFactory( - bundles=[{ - # Driver. - "CPU": cf["num_cpus_for_driver"], - "GPU": 0 if cf["_fake_gpus"] else cf["num_gpus"], - }] + [ - { - # RolloutWorkers. - "CPU": cf["num_cpus_per_worker"], - "GPU": cf["num_gpus_per_worker"], - } for _ in range(cf["num_workers"]) - ] + ([ - { - # Evaluation workers. - # Note: The local eval worker is located on the driver CPU. - "CPU": eval_config.get("num_cpus_per_worker", - cf["num_cpus_per_worker"]), - "GPU": eval_config.get("num_gpus_per_worker", - cf["num_gpus_per_worker"]), - } for _ in range(cf["evaluation_num_workers"]) - ] if cf["evaluation_interval"] else []), - strategy=config.get("placement_strategy", "PACK")) - - @override(Trainable) - @PublicAPI - def train(self) -> ResultDict: - """Overrides super.train to synchronize global vars.""" - - result = None - for _ in range(1 + MAX_WORKER_FAILURE_RETRIES): - try: - result = Trainable.train(self) - except RayError as e: - if self.config["ignore_worker_failures"]: - logger.exception( - "Error in train call, attempting to recover") - self._try_recover() - else: - logger.info( - "Worker crashed during call to train(). To attempt to " - "continue training without the failed worker, set " - "`'ignore_worker_failures': True`.") - raise e - except Exception as e: - time.sleep(0.5) # allow logs messages to propagate - raise e - else: - break - if result is None: - raise RuntimeError("Failed to recover from worker crash") - - if hasattr(self, "workers") and isinstance(self.workers, WorkerSet): - self._sync_filters_if_needed(self.workers) - - return result - - def _sync_filters_if_needed(self, workers: WorkerSet): - if self.config.get("observation_filter", "NoFilter") != "NoFilter": - FilterManager.synchronize( - workers.local_worker().filters, - workers.remote_workers(), - update_remote=self.config["synchronize_filters"]) - logger.debug("synchronized filters: {}".format( - workers.local_worker().filters)) - - @override(Trainable) - def log_result(self, result: ResultDict): - self.callbacks.on_train_result(trainer=self, result=result) - # log after the callback is invoked, so that the user has a chance - # to mutate the result - Trainable.log_result(self, result) - - @DeveloperAPI - def _create_local_replay_buffer_if_necessary(self, config): - """Create a LocalReplayBuffer instance if necessary. - - Args: - config (dict): Algorithm-specific configuration data. - - Returns: - LocalReplayBuffer instance based on trainer config. - None, if local replay buffer is not needed. - """ - # These are the agents that utilizes a local replay buffer. - if ("replay_buffer_config" not in config - or not config["replay_buffer_config"]): - # Does not need a replay buffer. - return None - - replay_buffer_config = config["replay_buffer_config"] - if ("type" not in replay_buffer_config - or replay_buffer_config["type"] != "LocalReplayBuffer"): - # DistributedReplayBuffer coming soon. - return None - - capacity = config.get("buffer_size", DEPRECATED_VALUE) - if capacity != DEPRECATED_VALUE: - # Print a deprecation warning. - deprecation_warning( - old="config['buffer_size']", - new="config['replay_buffer_config']['capacity']", - error=False) - else: - # Get capacity out of replay_buffer_config. - capacity = replay_buffer_config["capacity"] - - if config.get("prioritized_replay"): - prio_args = { - "prioritized_replay_alpha": config["prioritized_replay_alpha"], - "prioritized_replay_beta": config["prioritized_replay_beta"], - "prioritized_replay_eps": config["prioritized_replay_eps"], - } - else: - prio_args = {} - - return LocalReplayBuffer( - num_shards=1, - learning_starts=config["learning_starts"], - capacity=capacity, - replay_batch_size=config["train_batch_size"], - replay_mode=config["multiagent"]["replay_mode"], - replay_sequence_length=config.get("replay_sequence_length", 1), - replay_burn_in=config.get("burn_in", 0), - replay_zero_init_states=config.get("zero_init_states", True), - **prio_args) - - @DeveloperAPI - def _kwargs_for_execution_plan(self): - kwargs = {} - if self.local_replay_buffer: - kwargs["local_replay_buffer"] = self.local_replay_buffer - return kwargs - @override(Trainable) def setup(self, config: PartialTrainerConfigDict): env = self._env_id @@ -839,6 +712,8 @@ def env_creator_from_classpath(env_context): self.local_replay_buffer = ( self._create_local_replay_buffer_if_necessary(self.config)) + # Make the call to self._init. Sub-classes should override this + # method to implement custom initialization logic. self._init(self.config, self.env_creator) # Evaluation setup. @@ -875,69 +750,53 @@ def env_creator_from_classpath(env_context): config=evaluation_config, num_workers=self.config["evaluation_num_workers"]) - @override(Trainable) - def cleanup(self): - if hasattr(self, "workers"): - self.workers.stop() - if hasattr(self, "optimizer") and self.optimizer: - self.optimizer.stop() - - @override(Trainable) - def save_checkpoint(self, checkpoint_dir: str) -> str: - checkpoint_path = os.path.join(checkpoint_dir, - "checkpoint-{}".format(self.iteration)) - pickle.dump(self.__getstate__(), open(checkpoint_path, "wb")) - - return checkpoint_path - - @override(Trainable) - def load_checkpoint(self, checkpoint_path: str): - extra_data = pickle.load(open(checkpoint_path, "rb")) - self.__setstate__(extra_data) - @DeveloperAPI - def _make_workers( - self, *, env_creator: Callable[[EnvContext], EnvType], - validate_env: Optional[Callable[[EnvType, EnvContext], None]], - policy_class: Type[Policy], config: TrainerConfigDict, - num_workers: int) -> WorkerSet: - """Default factory method for a WorkerSet running under this Trainer. + def _init(self, config: TrainerConfigDict, + env_creator: Callable[[EnvContext], EnvType]) -> None: + """Subclasses should override this for custom initialization. - Override this method by passing a custom `make_workers` into - `build_trainer`. + In the case of Trainer, this is called from inside `self.setup()`. Args: - env_creator (callable): A function that return and Env given an env - config. - validate_env (Optional[Callable[[EnvType, EnvContext], None]]): - Optional callable to validate the generated environment (only - on worker=0). - policy (Type[Policy]): The Policy class to use for creating the - policies of the workers. - config (TrainerConfigDict): The Trainer's config. - num_workers (int): Number of remote rollout workers to create. - 0 for local only. - - Returns: - WorkerSet: The created WorkerSet. + config: Algorithm-specific configuration dict. + env_creator: A callable taking an EnvContext as only arg and + returning an environment (of any type: e.g. gym.Env, RLlib + BaseEnv, MultiAgentEnv, etc..). """ - return WorkerSet( - env_creator=env_creator, - validate_env=validate_env, - policy_class=policy_class, - trainer_config=config, - num_workers=num_workers, - logdir=self.logdir) - - @DeveloperAPI - def _init(self, config: TrainerConfigDict, - env_creator: Callable[[EnvContext], EnvType]): - """Subclasses should override this for custom initialization.""" raise NotImplementedError - @Deprecated(new="Trainer.evaluate", error=False) - def _evaluate(self) -> dict: - return self.evaluate() + @override(Trainable) + @PublicAPI + def train(self) -> ResultDict: + """Overrides super.train to synchronize global vars.""" + + result = None + for _ in range(1 + MAX_WORKER_FAILURE_RETRIES): + try: + result = Trainable.train(self) + except RayError as e: + if self.config["ignore_worker_failures"]: + logger.exception( + "Error in train call, attempting to recover") + self._try_recover() + else: + logger.info( + "Worker crashed during call to train(). To attempt to " + "continue training without the failed worker, set " + "`'ignore_worker_failures': True`.") + raise e + except Exception as e: + time.sleep(0.5) # allow logs messages to propagate + raise e + else: + break + if result is None: + raise RuntimeError("Failed to recover from worker crash") + + if hasattr(self, "workers") and isinstance(self.workers, WorkerSet): + self._sync_filters_if_needed(self.workers) + + return result @PublicAPI def evaluate(self, episodes_left_fn: Optional[Callable[[int], int]] = None @@ -948,10 +807,10 @@ def evaluate(self, episodes_left_fn: Optional[Callable[[int], int]] = None merging evaluation_config with the normal trainer config. Args: - episodes_left_fn (Optional[Callable[[int], int]]): An optional - callable taking the already run num episodes as only arg - and returning the number of episodes left to run. It's used - to find out whether evaluation should continue. + episodes_left_fn: An optional callable taking the already run + num episodes as only arg and returning the number of + episodes left to run. It's used to find out whether + evaluation should continue. """ # In case we are evaluating (in a thread) parallel to training, # we may have to re-enable eager mode here (gets disabled in the @@ -963,8 +822,8 @@ def evaluate(self, episodes_left_fn: Optional[Callable[[int], int]] = None # Call the `_before_evaluate` hook. self._before_evaluate() + # Sync weights to the evaluation WorkerSet. if self.evaluation_workers is not None: - # Sync weights to the evaluation WorkerSet. self._sync_weights_to_workers(worker_set=self.evaluation_workers) self._sync_filters_if_needed(self.evaluation_workers) @@ -1053,25 +912,6 @@ def episodes_left_fn(num_episodes_done): self.evaluation_workers.remote_workers()) return {"evaluation": metrics} - @DeveloperAPI - def _before_evaluate(self): - """Pre-evaluation callback.""" - pass - - @DeveloperAPI - def _sync_weights_to_workers( - self, - *, - worker_set: Optional[WorkerSet] = None, - workers: Optional[List[RolloutWorker]] = None, - ) -> None: - """Sync "main" weights to given WorkerSet or list of workers.""" - assert worker_set is not None - # Broadcast the new policy weights to all evaluation workers. - logger.info("Synchronizing weights to workers.") - weights = ray.put(self.workers.local_worker().save()) - worker_set.foreach_worker(lambda w: w.restore(ray.get(weights))) - @PublicAPI def compute_single_action( self, @@ -1223,10 +1063,6 @@ def compute_single_action( else: return action - @Deprecated(new="compute_single_action", error=False) - def compute_action(self, *args, **kwargs): - return self.compute_single_action(*args, **kwargs) - @PublicAPI def compute_actions( self, @@ -1253,7 +1089,7 @@ def compute_actions( self.get_policy(policy_id) and call compute_actions() on it directly. Args: - observation: observation from the environment. + observation: Observation from the environment. state: RNN hidden state, if any. If state is not None, then all of compute_single_action(...) is returned (computed action, rnn state(s), logits dictionary). @@ -1284,7 +1120,7 @@ def compute_actions( Returns: any: The computed action if full_fetch=False, or tuple: The full output of policy.compute_actions() if - full_fetch=True or we have an RNN-based Policy. + full_fetch=True or we have an RNN-based Policy. """ if normalize_actions is not None: deprecation_warning( @@ -1359,31 +1195,21 @@ def compute_actions( else: return actions - @property - def _name(self) -> str: - """Subclasses should override this to declare their name.""" - raise NotImplementedError - - @property - def _default_config(self) -> TrainerConfigDict: - """Subclasses should override this to declare their default config.""" - raise NotImplementedError - @PublicAPI def get_policy(self, policy_id: PolicyID = DEFAULT_POLICY_ID) -> Policy: """Return policy for the specified id, or None. Args: - policy_id (PolicyID): ID of the policy to return. + policy_id: ID of the policy to return. """ return self.workers.local_worker().get_policy(policy_id) @PublicAPI - def get_weights(self, policies: List[PolicyID] = None) -> dict: + def get_weights(self, policies: Optional[List[PolicyID]] = None) -> dict: """Return a dictionary of policy ids to weights. Args: - policies (list): Optional list of policies to return weights for, + policies: Optional list of policies to return weights for, or None for all policies. """ return self.workers.local_worker().get_weights(policies) @@ -1393,7 +1219,7 @@ def set_weights(self, weights: Dict[PolicyID, dict]): """Set policy weights by policy id. Args: - weights (dict): Map of policy ids to weights to set. + weights: Map of policy ids to weights to set. """ self.workers.local_worker().set_weights(weights) @@ -1502,35 +1328,38 @@ def fn(worker): def export_policy_model(self, export_dir: str, policy_id: PolicyID = DEFAULT_POLICY_ID, - onnx: Optional[int] = None): - """Export policy model with given policy_id to local directory. + onnx: Optional[int] = None) -> None: + """Exports policy model with given policy_id to a local directory. Args: - export_dir (string): Writable local directory. - policy_id (string): Optional policy id to export. - onnx (int): If given, will export model in ONNX format. The + export_dir: Writable local directory. + policy_id: Optional policy id to export. + onnx: If given, will export model in ONNX format. The value of this parameter set the ONNX OpSet version to use. + If None, the output format will be DL framework specific. Example: >>> trainer = MyTrainer() >>> for _ in range(10): >>> trainer.train() - >>> trainer.export_policy_model("/tmp/export_dir") + >>> trainer.export_policy_model("/tmp/dir") + >>> trainer.export_policy_model("/tmp/dir/onnx", onnx=1) """ - self.workers.local_worker().export_policy_model( - export_dir, policy_id, onnx) + self.get_policy(policy_id).export_model(export_dir, onnx) @DeveloperAPI - def export_policy_checkpoint(self, - export_dir: str, - filename_prefix: str = "model", - policy_id: PolicyID = DEFAULT_POLICY_ID): - """Export tensorflow policy model checkpoint to local directory. + def export_policy_checkpoint( + self, + export_dir: str, + filename_prefix: str = "model", + policy_id: PolicyID = DEFAULT_POLICY_ID, + ) -> None: + """Exports policy model checkpoint to a local directory. Args: - export_dir (string): Writable local directory. - filename_prefix (string): file name prefix of checkpoint files. - policy_id (string): Optional policy id to export. + export_dir: Writable local directory. + filename_prefix: file name prefix of checkpoint files. + policy_id: Optional policy id to export. Example: >>> trainer = MyTrainer() @@ -1538,18 +1367,20 @@ def export_policy_checkpoint(self, >>> trainer.train() >>> trainer.export_policy_checkpoint("/tmp/export_dir") """ - self.workers.local_worker().export_policy_checkpoint( - export_dir, filename_prefix, policy_id) + self.get_policy(policy_id).export_checkpoint(export_dir, + filename_prefix) @DeveloperAPI - def import_policy_model_from_h5(self, - import_file: str, - policy_id: PolicyID = DEFAULT_POLICY_ID): + def import_policy_model_from_h5( + self, + import_file: str, + policy_id: PolicyID = DEFAULT_POLICY_ID, + ) -> None: """Imports a policy's model with given policy_id from a local h5 file. Args: - import_file (str): The h5 file to import from. - policy_id (string): Optional policy id to import into. + import_file: The h5 file to import from. + policy_id: Optional policy id to import into. Example: >>> trainer = MyTrainer() @@ -1557,8 +1388,9 @@ def import_policy_model_from_h5(self, >>> for _ in range(10): >>> trainer.train() """ - self.workers.local_worker().import_policy_model_from_h5( - import_file, policy_id) + self.get_policy(policy_id).import_model_from_h5(import_file) + # Sync new weights to remote workers. + self._sync_weights_to_workers() @DeveloperAPI def collect_metrics(self, @@ -1572,6 +1404,156 @@ def collect_metrics(self, min_history=self.config["metrics_smoothing_episodes"], selected_workers=selected_workers) + @override(Trainable) + def save_checkpoint(self, checkpoint_dir: str) -> str: + checkpoint_path = os.path.join(checkpoint_dir, + "checkpoint-{}".format(self.iteration)) + pickle.dump(self.__getstate__(), open(checkpoint_path, "wb")) + + return checkpoint_path + + @override(Trainable) + def load_checkpoint(self, checkpoint_path: str) -> None: + extra_data = pickle.load(open(checkpoint_path, "rb")) + self.__setstate__(extra_data) + + @override(Trainable) + def log_result(self, result: ResultDict) -> None: + # Log after the callback is invoked, so that the user has a chance + # to mutate the result. + self.callbacks.on_train_result(trainer=self, result=result) + # Then log according to Trainable's logging logic. + Trainable.log_result(self, result) + + @override(Trainable) + def cleanup(self) -> None: + # Stop all workers. + if hasattr(self, "workers"): + self.workers.stop() + # Stop all optimizers. + if hasattr(self, "optimizer") and self.optimizer: + self.optimizer.stop() + + @classmethod + @override(Trainable) + def default_resource_request( + cls, config: PartialTrainerConfigDict) -> \ + Union[Resources, PlacementGroupFactory]: + + # Default logic for RLlib algorithms (Trainers): + # Create one bundle per individual worker (local or remote). + # Use `num_cpus_for_driver` and `num_gpus` for the local worker and + # `num_cpus_per_worker` and `num_gpus_per_worker` for the remote + # workers to determine their CPU/GPU resource needs. + + # Convenience config handles. + cf = dict(cls._default_config, **config) + eval_cf = cf["evaluation_config"] + + # TODO(ekl): add custom resources here once tune supports them + # Return PlacementGroupFactory containing all needed resources + # (already properly defined as device bundles). + return PlacementGroupFactory( + bundles=[{ + # Local worker. + "CPU": cf["num_cpus_for_driver"], + "GPU": 0 if cf["_fake_gpus"] else cf["num_gpus"], + }] + [ + { + # RolloutWorkers. + "CPU": cf["num_cpus_per_worker"], + "GPU": cf["num_gpus_per_worker"], + } for _ in range(cf["num_workers"]) + ] + ([ + { + # Evaluation workers. + # Note: The local eval worker is located on the driver CPU. + "CPU": eval_cf.get("num_cpus_per_worker", + cf["num_cpus_per_worker"]), + "GPU": eval_cf.get("num_gpus_per_worker", + cf["num_gpus_per_worker"]), + } for _ in range(cf["evaluation_num_workers"]) + ] if cf["evaluation_interval"] else []), + strategy=config.get("placement_strategy", "PACK")) + + @DeveloperAPI + def _before_evaluate(self): + """Pre-evaluation callback.""" + pass + + @DeveloperAPI + def _make_workers( + self, + *, + env_creator: Callable[[EnvContext], EnvType], + validate_env: Optional[Callable[[EnvType, EnvContext], None]], + policy_class: Type[Policy], + config: TrainerConfigDict, + num_workers: int, + ) -> WorkerSet: + """Default factory method for a WorkerSet running under this Trainer. + + Override this method by passing a custom `make_workers` into + `build_trainer`. + + Args: + env_creator: A function that return and Env given an env + config. + validate_env: Optional callable to validate the generated + environment. The env to be checked is the one returned from + the env creator, which may be a (single, not-yet-vectorized) + gym.Env or your custom RLlib env type (e.g. MultiAgentEnv, + VectorEnv, BaseEnv, etc..). + policy_class: The Policy class to use for creating the policies + of the workers. + config: The Trainer's config. + num_workers: Number of remote rollout workers to create. + 0 for local only. + + Returns: + The created WorkerSet. + """ + return WorkerSet( + env_creator=env_creator, + validate_env=validate_env, + policy_class=policy_class, + trainer_config=config, + num_workers=num_workers, + logdir=self.logdir) + + def _sync_filters_if_needed(self, workers: WorkerSet): + if self.config.get("observation_filter", "NoFilter") != "NoFilter": + FilterManager.synchronize( + workers.local_worker().filters, + workers.remote_workers(), + update_remote=self.config["synchronize_filters"]) + logger.debug("synchronized filters: {}".format( + workers.local_worker().filters)) + + @DeveloperAPI + def _sync_weights_to_workers( + self, + *, + worker_set: Optional[WorkerSet] = None, + workers: Optional[List[RolloutWorker]] = None, + ) -> None: + """Sync "main" weights to given WorkerSet or list of workers.""" + assert worker_set is not None + # Broadcast the new policy weights to all evaluation workers. + logger.info("Synchronizing weights to workers.") + weights = ray.put(self.workers.local_worker().save()) + worker_set.foreach_worker(lambda w: w.restore(ray.get(weights))) + + @property + def _name(self) -> str: + """Subclasses should override this to declare their name.""" + raise NotImplementedError + + @property + def _default_config(self) -> TrainerConfigDict: + """Subclasses should override this to declare their default config.""" + raise NotImplementedError + @classmethod @override(Trainable) def resource_help(cls, config: TrainerConfigDict) -> str: @@ -1909,6 +1891,69 @@ def with_updates(**overrides) -> Type["Trainer"]: "that were generated via the `ray.rllib.agents.trainer_template." "build_trainer()` function!") + @DeveloperAPI + def _create_local_replay_buffer_if_necessary( + self, + config: PartialTrainerConfigDict) -> Optional[LocalReplayBuffer]: + """Create a LocalReplayBuffer instance if necessary. + + Args: + config: Algorithm-specific configuration data. + + Returns: + LocalReplayBuffer instance based on trainer config. + None, if local replay buffer is not needed. + """ + # These are the agents that utilizes a local replay buffer. + if ("replay_buffer_config" not in config + or not config["replay_buffer_config"]): + # Does not need a replay buffer. + return None + + replay_buffer_config = config["replay_buffer_config"] + if ("type" not in replay_buffer_config + or replay_buffer_config["type"] != "LocalReplayBuffer"): + # DistributedReplayBuffer coming soon. + return None + + capacity = config.get("buffer_size", DEPRECATED_VALUE) + if capacity != DEPRECATED_VALUE: + # Print a deprecation warning. + deprecation_warning( + old="config['buffer_size']", + new="config['replay_buffer_config']['capacity']", + error=False) + else: + # Get capacity out of replay_buffer_config. + capacity = replay_buffer_config["capacity"] + + if config.get("prioritized_replay"): + prio_args = { + "prioritized_replay_alpha": config["prioritized_replay_alpha"], + "prioritized_replay_beta": config["prioritized_replay_beta"], + "prioritized_replay_eps": config["prioritized_replay_eps"], + } + else: + prio_args = {} + + return LocalReplayBuffer( + num_shards=1, + learning_starts=config["learning_starts"], + capacity=capacity, + replay_batch_size=config["train_batch_size"], + replay_mode=config["multiagent"]["replay_mode"], + replay_sequence_length=config.get("replay_sequence_length", 1), + replay_burn_in=config.get("burn_in", 0), + replay_zero_init_states=config.get("zero_init_states", True), + **prio_args) + + @DeveloperAPI + def _kwargs_for_execution_plan(self): + kwargs = {} + if self.local_replay_buffer: + kwargs["local_replay_buffer"] = self.local_replay_buffer + return kwargs + def _register_if_needed(self, env_object: Union[str, EnvType, None], config) -> Optional[str]: if isinstance(env_object, str): @@ -1939,5 +1984,13 @@ def _is_multi_agent(self): "You can specify a custom env as either a class " "(e.g., YourEnvCls) or a registered env id (e.g., \"your_env\").") + @Deprecated(new="Trainer.evaluate", error=False) + def _evaluate(self) -> dict: + return self.evaluate() + + @Deprecated(new="compute_single_action", error=False) + def compute_action(self, *args, **kwargs): + return self.compute_single_action(*args, **kwargs) + def __repr__(self): return self._name diff --git a/rllib/agents/trainer_template.py b/rllib/agents/trainer_template.py index b3a8ff71c29c..450d91cefbbf 100644 --- a/rllib/agents/trainer_template.py +++ b/rllib/agents/trainer_template.py @@ -75,58 +75,50 @@ def build_trainer( allow_unknown_subkeys: Optional[List[str]] = None, override_all_subkeys_if_type_changes: Optional[List[str]] = None, ) -> Type[Trainer]: - """Helper function for defining a custom trainer. + """Helper function for defining a custom Trainer class. Functions will be run in this order to initialize the trainer: - 1. Config setup: validate_config, get_policy - 2. Worker setup: before_init, execution_plan - 3. Post setup: after_init + 1. Config setup: validate_config, get_policy. + 2. Worker setup: before_init, execution_plan. + 3. Post setup: after_init. Args: - name (str): name of the trainer (e.g., "PPO") - default_config (Optional[TrainerConfigDict]): The default config dict - of the algorithm, otherwise uses the Trainer default config. - validate_config (Optional[Callable[[TrainerConfigDict], None]]): - Optional callable that takes the config to check for correctness. - It may mutate the config as needed. - default_policy (Optional[Type[Policy]]): The default Policy class to - use if `get_policy_class` returns None. - get_policy_class (Optional[Callable[ - TrainerConfigDict, Optional[Type[Policy]]]]): Optional callable - that takes a config and returns the policy class or None. If None - is returned, will use `default_policy` (which must be provided - then). - validate_env (Optional[Callable[[EnvType, EnvContext], None]]): - Optional callable to validate the generated environment (only - on worker=0). - before_init (Optional[Callable[[Trainer], None]]): Optional callable to - run before anything is constructed inside Trainer (Workers with - Policies, execution plan, etc..). Takes the Trainer instance as - argument. - after_init (Optional[Callable[[Trainer], None]]): Optional callable to - run at the end of trainer init (after all Workers and the exec. - plan have been constructed). Takes the Trainer instance as - argument. - before_evaluate_fn (Optional[Callable[[Trainer], None]]): Callback to - run before evaluation. This takes the trainer instance as argument. - mixins (list): list of any class mixins for the returned trainer class. + name: name of the trainer (e.g., "PPO") + default_config: The default config dict of the algorithm, + otherwise uses the Trainer default config. + validate_config: Optional callable that takes the config to check + for correctness. It may mutate the config as needed. + default_policy: The default Policy class to use if `get_policy_class` + returns None. + get_policy_class: Optional callable that takes a config and returns + the policy class or None. If None is returned, will use + `default_policy` (which must be provided then). + validate_env: Optional callable to validate the generated environment + (only on worker=0). + before_init: Optional callable to run before anything is constructed + inside Trainer (Workers with Policies, execution plan, etc..). + Takes the Trainer instance as argument. + after_init: Optional callable to run at the end of trainer init + (after all Workers and the exec. plan have been constructed). + Takes the Trainer instance as argument. + before_evaluate_fn: Callback to run before evaluation. This takes + the trainer instance as argument. + mixins: List of any class mixins for the returned trainer class. These mixins will be applied in order and will have higher precedence than the Trainer class. - execution_plan (Optional[Callable[[WorkerSet, TrainerConfigDict], - Iterable[ResultDict]]]): Optional callable that sets up the + execution_plan: Optional callable that sets up the distributed execution workflow. - allow_unknown_configs (bool): Whether to allow unknown top-level config - keys. - allow_unknown_subkeys (Optional[List[str]]): List of top-level keys + allow_unknown_configs: Whether to allow unknown top-level config keys. + allow_unknown_subkeys: List of top-level keys with value=dict, for which new sub-keys are allowed to be added to the value dict. Appends to Trainer class defaults. - override_all_subkeys_if_type_changes (Optional[List[str]]): List of top - level keys with value=dict, for which we always override the entire - value (dict), iff the "type" key in that value dict changes. - Appends to Trainer class defaults. + override_all_subkeys_if_type_changes: List of top level keys with + value=dict, for which we always override the entire value (dict), + iff the "type" key in that value dict changes. Appends to Trainer + class defaults. Returns: - Type[Trainer]: A Trainer sub-class configured by the specified args. + A Trainer sub-class configured by the specified args. """ original_kwargs = locals().copy() From 0b308719f88317ef6d9472a1fe361311236a7015 Mon Sep 17 00:00:00 2001 From: Sven Mika Date: Mon, 1 Nov 2021 21:46:02 +0100 Subject: [PATCH 07/15] [RLlib; Docs overhaul] Docstring cleanup: rllib/utils (#19829) --- rllib/agents/a3c/a3c_tf_policy.py | 4 +- rllib/agents/a3c/a3c_torch_policy.py | 2 +- rllib/agents/ars/ars.py | 3 +- rllib/agents/ddpg/ddpg_tf_policy.py | 2 +- rllib/agents/dqn/dqn_tf_policy.py | 4 +- rllib/agents/dqn/learner_thread.py | 2 +- rllib/agents/dqn/r2d2_tf_policy.py | 2 +- rllib/agents/dqn/simple_q_tf_policy.py | 2 +- rllib/agents/es/es.py | 3 +- rllib/agents/impala/vtrace_tf_policy.py | 2 +- rllib/agents/marwil/marwil_tf_policy.py | 2 +- rllib/agents/ppo/appo_tf_policy.py | 2 +- rllib/agents/ppo/ppo_tf_policy.py | 6 +- rllib/agents/registry.py | 2 +- rllib/agents/sac/sac_tf_policy.py | 2 +- rllib/agents/trainer.py | 5 +- rllib/evaluation/rollout_worker.py | 2 +- rllib/evaluation/sample_batch_builder.py | 3 +- rllib/examples/centralized_critic.py | 2 +- rllib/examples/env/multi_agent.py | 2 +- .../trajectory_view_utilizing_models.py | 2 +- rllib/execution/learner_thread.py | 2 +- rllib/execution/replay_buffer.py | 6 +- rllib/models/catalog.py | 4 +- rllib/models/modelv2.py | 3 +- rllib/models/tf/attention_net.py | 2 +- rllib/models/tf/complex_input_net.py | 2 +- rllib/models/tf/recurrent_net.py | 2 +- rllib/models/torch/torch_action_dist.py | 3 +- rllib/policy/dynamic_tf_policy.py | 2 +- rllib/policy/eager_tf_policy.py | 4 +- rllib/policy/policy.py | 3 +- rllib/policy/policy_map.py | 2 +- rllib/policy/sample_batch.py | 6 +- rllib/policy/tf_policy.py | 6 +- rllib/policy/torch_policy.py | 8 +- rllib/policy/torch_policy_template.py | 2 +- rllib/utils/annotations.py | 68 +-- rllib/utils/debug.py | 14 +- rllib/utils/deprecation.py | 65 ++- rllib/utils/exploration/curiosity.py | 2 +- rllib/utils/exploration/exploration.py | 3 +- rllib/utils/exploration/gaussian_noise.py | 2 +- .../exploration/ornstein_uhlenbeck_noise.py | 2 +- rllib/utils/exploration/random.py | 2 +- .../utils/exploration/stochastic_sampling.py | 2 +- rllib/utils/framework.py | 40 +- rllib/utils/memory.py | 81 +--- rllib/utils/metrics/window_stat.py | 28 ++ rllib/utils/multi_agent.py | 7 +- rllib/utils/numpy.py | 392 +++++++++------- rllib/utils/test_utils.py | 10 +- rllib/utils/tf_ops.py | 313 +------------ rllib/utils/tf_utils.py | 426 ++++++++++++++++++ rllib/utils/threading.py | 6 +- rllib/utils/torch_ops.py | 263 +---------- rllib/utils/torch_utils.py | 395 ++++++++++++++++ rllib/utils/window_stat.py | 37 +- 58 files changed, 1301 insertions(+), 970 deletions(-) create mode 100644 rllib/utils/metrics/window_stat.py create mode 100644 rllib/utils/tf_utils.py create mode 100644 rllib/utils/torch_utils.py diff --git a/rllib/agents/a3c/a3c_tf_policy.py b/rllib/agents/a3c/a3c_tf_policy.py index 78ea8cf8e28d..0dd87a9b233f 100644 --- a/rllib/agents/a3c/a3c_tf_policy.py +++ b/rllib/agents/a3c/a3c_tf_policy.py @@ -14,9 +14,9 @@ from ray.rllib.policy.tf_policy import LearningRateSchedule, \ EntropyCoeffSchedule from ray.rllib.policy.tf_policy_template import build_tf_policy -from ray.rllib.utils.annotations import Deprecated +from ray.rllib.utils.deprecation import Deprecated from ray.rllib.utils.framework import try_import_tf -from ray.rllib.utils.tf_ops import explained_variance +from ray.rllib.utils.tf_utils import explained_variance from ray.rllib.utils.typing import TrainerConfigDict, TensorType, \ PolicyID, LocalOptimizer, ModelGradients diff --git a/rllib/agents/a3c/a3c_torch_policy.py b/rllib/agents/a3c/a3c_torch_policy.py index 10508d073f63..557d7eb53fc4 100644 --- a/rllib/agents/a3c/a3c_torch_policy.py +++ b/rllib/agents/a3c/a3c_torch_policy.py @@ -13,7 +13,7 @@ from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.torch_policy import LearningRateSchedule, \ EntropyCoeffSchedule -from ray.rllib.utils.annotations import Deprecated +from ray.rllib.utils.deprecation import Deprecated from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.torch_ops import apply_grad_clipping, sequence_mask from ray.rllib.utils.typing import TrainerConfigDict, TensorType, \ diff --git a/rllib/agents/ars/ars.py b/rllib/agents/ars/ars.py index ef83ea6ad247..bf55e38fa8af 100644 --- a/rllib/agents/ars/ars.py +++ b/rllib/agents/ars/ars.py @@ -16,7 +16,8 @@ from ray.rllib.agents.es.es_tf_policy import rollout from ray.rllib.env.env_context import EnvContext from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID -from ray.rllib.utils.annotations import Deprecated, override +from ray.rllib.utils.annotations import override +from ray.rllib.utils.deprecation import Deprecated from ray.rllib.utils.torch_ops import set_torch_seed from ray.rllib.utils import FilterManager diff --git a/rllib/agents/ddpg/ddpg_tf_policy.py b/rllib/agents/ddpg/ddpg_tf_policy.py index d3c295feba94..be185f6d634d 100644 --- a/rllib/agents/ddpg/ddpg_tf_policy.py +++ b/rllib/agents/ddpg/ddpg_tf_policy.py @@ -26,7 +26,7 @@ from ray.rllib.utils.error import UnsupportedSpaceException from ray.rllib.utils.framework import get_variable, try_import_tf from ray.rllib.utils.spaces.simplex import Simplex -from ray.rllib.utils.tf_ops import huber_loss, make_tf_callable +from ray.rllib.utils.tf_utils import huber_loss, make_tf_callable from ray.rllib.utils.typing import TrainerConfigDict, TensorType, \ LocalOptimizer, ModelGradients from ray.util.debug import log_once diff --git a/rllib/agents/dqn/dqn_tf_policy.py b/rllib/agents/dqn/dqn_tf_policy.py index 88e3e30e06fa..d24ee4477392 100644 --- a/rllib/agents/dqn/dqn_tf_policy.py +++ b/rllib/agents/dqn/dqn_tf_policy.py @@ -20,8 +20,8 @@ from ray.rllib.utils.exploration import ParameterNoise from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.numpy import convert_to_numpy -from ray.rllib.utils.tf_ops import (huber_loss, make_tf_callable, - minimize_and_clip, reduce_mean_ignore_inf) +from ray.rllib.utils.tf_utils import ( + huber_loss, make_tf_callable, minimize_and_clip, reduce_mean_ignore_inf) from ray.rllib.utils.typing import (ModelGradients, TensorType, TrainerConfigDict) diff --git a/rllib/agents/dqn/learner_thread.py b/rllib/agents/dqn/learner_thread.py index 93bed4b18de5..6e7b1ebae348 100644 --- a/rllib/agents/dqn/learner_thread.py +++ b/rllib/agents/dqn/learner_thread.py @@ -3,8 +3,8 @@ from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.metrics.learner_info import LearnerInfoBuilder +from ray.rllib.utils.metrics.window_stat import WindowStat from ray.rllib.utils.timer import TimerStat -from ray.rllib.utils.window_stat import WindowStat LEARNER_QUEUE_MAX_SIZE = 16 diff --git a/rllib/agents/dqn/r2d2_tf_policy.py b/rllib/agents/dqn/r2d2_tf_policy.py index d34c35a44976..2f922a7dfd03 100644 --- a/rllib/agents/dqn/r2d2_tf_policy.py +++ b/rllib/agents/dqn/r2d2_tf_policy.py @@ -17,7 +17,7 @@ from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.tf_policy import LearningRateSchedule from ray.rllib.utils.framework import try_import_tf -from ray.rllib.utils.tf_ops import huber_loss +from ray.rllib.utils.tf_utils import huber_loss from ray.rllib.utils.typing import ModelInputDict, TensorType, \ TrainerConfigDict diff --git a/rllib/agents/dqn/simple_q_tf_policy.py b/rllib/agents/dqn/simple_q_tf_policy.py index 13e62bca1fd9..49674c5e752e 100644 --- a/rllib/agents/dqn/simple_q_tf_policy.py +++ b/rllib/agents/dqn/simple_q_tf_policy.py @@ -18,7 +18,7 @@ from ray.rllib.utils.annotations import override from ray.rllib.utils.error import UnsupportedSpaceException from ray.rllib.utils.framework import try_import_tf -from ray.rllib.utils.tf_ops import huber_loss, make_tf_callable +from ray.rllib.utils.tf_utils import huber_loss, make_tf_callable from ray.rllib.utils.typing import TensorType, TrainerConfigDict tf1, tf, tfv = try_import_tf() diff --git a/rllib/agents/es/es.py b/rllib/agents/es/es.py index 796076b01d9d..3535bf245926 100644 --- a/rllib/agents/es/es.py +++ b/rllib/agents/es/es.py @@ -14,7 +14,8 @@ from ray.rllib.env.env_context import EnvContext from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID from ray.rllib.utils import FilterManager -from ray.rllib.utils.annotations import Deprecated, override +from ray.rllib.utils.annotations import override +from ray.rllib.utils.deprecation import Deprecated from ray.rllib.utils.torch_ops import set_torch_seed logger = logging.getLogger(__name__) diff --git a/rllib/agents/impala/vtrace_tf_policy.py b/rllib/agents/impala/vtrace_tf_policy.py index 5a786a4da8e9..4e9594588b66 100644 --- a/rllib/agents/impala/vtrace_tf_policy.py +++ b/rllib/agents/impala/vtrace_tf_policy.py @@ -15,7 +15,7 @@ EntropyCoeffSchedule from ray.rllib.utils import force_list from ray.rllib.utils.framework import try_import_tf -from ray.rllib.utils.tf_ops import explained_variance +from ray.rllib.utils.tf_utils import explained_variance tf1, tf, tfv = try_import_tf() diff --git a/rllib/agents/marwil/marwil_tf_policy.py b/rllib/agents/marwil/marwil_tf_policy.py index 9a386671792e..2c0f25913236 100644 --- a/rllib/agents/marwil/marwil_tf_policy.py +++ b/rllib/agents/marwil/marwil_tf_policy.py @@ -9,7 +9,7 @@ Postprocessing from ray.rllib.policy.tf_policy_template import build_tf_policy from ray.rllib.utils.framework import try_import_tf, get_variable -from ray.rllib.utils.tf_ops import explained_variance, make_tf_callable +from ray.rllib.utils.tf_utils import explained_variance, make_tf_callable from ray.rllib.policy.policy import Policy from ray.rllib.utils.typing import TrainerConfigDict, TensorType, \ PolicyID diff --git a/rllib/agents/ppo/appo_tf_policy.py b/rllib/agents/ppo/appo_tf_policy.py index c579d21c3123..98af35d13a01 100644 --- a/rllib/agents/ppo/appo_tf_policy.py +++ b/rllib/agents/ppo/appo_tf_policy.py @@ -28,7 +28,7 @@ from ray.rllib.models.tf.tf_action_dist import TFActionDistribution from ray.rllib.utils.annotations import override from ray.rllib.utils.framework import try_import_tf -from ray.rllib.utils.tf_ops import explained_variance, make_tf_callable +from ray.rllib.utils.tf_utils import explained_variance, make_tf_callable from ray.rllib.utils.typing import AgentID, TensorType, TrainerConfigDict tf1, tf, tfv = try_import_tf() diff --git a/rllib/agents/ppo/ppo_tf_policy.py b/rllib/agents/ppo/ppo_tf_policy.py index 76610567abcb..6fc5ac27d88e 100644 --- a/rllib/agents/ppo/ppo_tf_policy.py +++ b/rllib/agents/ppo/ppo_tf_policy.py @@ -17,10 +17,10 @@ from ray.rllib.policy.tf_policy import LearningRateSchedule, \ EntropyCoeffSchedule from ray.rllib.policy.tf_policy_template import build_tf_policy -from ray.rllib.utils.annotations import Deprecated -from ray.rllib.utils.deprecation import DEPRECATED_VALUE, deprecation_warning +from ray.rllib.utils.deprecation import Deprecated, DEPRECATED_VALUE, \ + deprecation_warning from ray.rllib.utils.framework import try_import_tf, get_variable -from ray.rllib.utils.tf_ops import explained_variance, make_tf_callable +from ray.rllib.utils.tf_utils import explained_variance, make_tf_callable from ray.rllib.utils.typing import AgentID, LocalOptimizer, ModelGradients, \ TensorType, TrainerConfigDict diff --git a/rllib/agents/registry.py b/rllib/agents/registry.py index 6139afbb17f0..f5cfe0b41748 100644 --- a/rllib/agents/registry.py +++ b/rllib/agents/registry.py @@ -3,7 +3,7 @@ import traceback from ray.rllib.contrib.registry import CONTRIBUTED_ALGORITHMS -from ray.rllib.utils.annotations import Deprecated +from ray.rllib.utils.deprecation import Deprecated def _import_a2c(): diff --git a/rllib/agents/sac/sac_tf_policy.py b/rllib/agents/sac/sac_tf_policy.py index c6b5be01b6bf..97dbad921b1d 100644 --- a/rllib/agents/sac/sac_tf_policy.py +++ b/rllib/agents/sac/sac_tf_policy.py @@ -27,7 +27,7 @@ from ray.rllib.utils.error import UnsupportedSpaceException from ray.rllib.utils.framework import get_variable, try_import_tf from ray.rllib.utils.spaces.simplex import Simplex -from ray.rllib.utils.tf_ops import huber_loss +from ray.rllib.utils.tf_utils import huber_loss from ray.rllib.utils.typing import AgentID, LocalOptimizer, ModelGradients, \ TensorType, TrainerConfigDict diff --git a/rllib/agents/trainer.py b/rllib/agents/trainer.py index e0678a62fcb5..323bc01b3e6d 100644 --- a/rllib/agents/trainer.py +++ b/rllib/agents/trainer.py @@ -28,10 +28,11 @@ from ray.rllib.policy.policy import Policy, PolicySpec from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch from ray.rllib.utils import deep_update, FilterManager, merge_dicts -from ray.rllib.utils.annotations import Deprecated, DeveloperAPI, override, \ +from ray.rllib.utils.annotations import DeveloperAPI, override, \ PublicAPI from ray.rllib.utils.debug import update_global_seed_if_necessary -from ray.rllib.utils.deprecation import deprecation_warning, DEPRECATED_VALUE +from ray.rllib.utils.deprecation import Deprecated, deprecation_warning, \ + DEPRECATED_VALUE from ray.rllib.utils.error import EnvError, ERR_MSG_INVALID_ENV_DESCRIPTOR from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.from_config import from_config diff --git a/rllib/evaluation/rollout_worker.py b/rllib/evaluation/rollout_worker.py index 8fd1ebcf6208..360005c3c808 100644 --- a/rllib/evaluation/rollout_worker.py +++ b/rllib/evaluation/rollout_worker.py @@ -39,7 +39,7 @@ from ray.rllib.utils.filter import get_filter, Filter from ray.rllib.utils.framework import try_import_tf, try_import_torch from ray.rllib.utils.sgd import do_minibatch_sgd -from ray.rllib.utils.tf_ops import get_gpu_devices as get_tf_gpu_devices +from ray.rllib.utils.tf_utils import get_gpu_devices as get_tf_gpu_devices from ray.rllib.utils.tf_run_builder import TFRunBuilder from ray.rllib.utils.typing import AgentID, EnvConfigDict, EnvType, \ ModelConfigDict, ModelGradients, ModelWeights, \ diff --git a/rllib/evaluation/sample_batch_builder.py b/rllib/evaluation/sample_batch_builder.py index 7afeb14816f2..898970c308a7 100644 --- a/rllib/evaluation/sample_batch_builder.py +++ b/rllib/evaluation/sample_batch_builder.py @@ -7,7 +7,8 @@ from ray.rllib.evaluation.episode import Episode from ray.rllib.policy.policy import Policy from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch -from ray.rllib.utils.annotations import Deprecated, DeveloperAPI +from ray.rllib.utils.annotations import DeveloperAPI +from ray.rllib.utils.deprecation import Deprecated from ray.rllib.utils.debug import summarize from ray.rllib.utils.deprecation import deprecation_warning from ray.rllib.utils.typing import PolicyID, AgentID diff --git a/rllib/examples/centralized_critic.py b/rllib/examples/centralized_critic.py index 09fdb3b968de..60b968404138 100644 --- a/rllib/examples/centralized_critic.py +++ b/rllib/examples/centralized_critic.py @@ -38,7 +38,7 @@ EntropyCoeffSchedule as TorchEntropyCoeffSchedule from ray.rllib.utils.framework import try_import_tf, try_import_torch from ray.rllib.utils.test_utils import check_learning_achieved -from ray.rllib.utils.tf_ops import explained_variance, make_tf_callable +from ray.rllib.utils.tf_utils import explained_variance, make_tf_callable from ray.rllib.utils.torch_ops import convert_to_torch_tensor tf1, tf, tfv = try_import_tf() diff --git a/rllib/examples/env/multi_agent.py b/rllib/examples/env/multi_agent.py index 4e052e70ecb5..a1d19eea1bc5 100644 --- a/rllib/examples/env/multi_agent.py +++ b/rllib/examples/env/multi_agent.py @@ -4,7 +4,7 @@ from ray.rllib.env.multi_agent_env import MultiAgentEnv, make_multi_agent from ray.rllib.examples.env.mock_env import MockEnv, MockEnv2 from ray.rllib.examples.env.stateless_cartpole import StatelessCartPole -from ray.rllib.utils.annotations import Deprecated +from ray.rllib.utils.deprecation import Deprecated @Deprecated( diff --git a/rllib/examples/models/trajectory_view_utilizing_models.py b/rllib/examples/models/trajectory_view_utilizing_models.py index 0fd4e22cb145..c38ba2a6f7e2 100644 --- a/rllib/examples/models/trajectory_view_utilizing_models.py +++ b/rllib/examples/models/trajectory_view_utilizing_models.py @@ -3,7 +3,7 @@ from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 from ray.rllib.policy.view_requirement import ViewRequirement from ray.rllib.utils.framework import try_import_tf, try_import_torch -from ray.rllib.utils.tf_ops import one_hot +from ray.rllib.utils.tf_utils import one_hot from ray.rllib.utils.torch_ops import one_hot as torch_one_hot tf1, tf, tfv = try_import_tf() diff --git a/rllib/execution/learner_thread.py b/rllib/execution/learner_thread.py index d8c6f93c146b..3fb8a3195eda 100644 --- a/rllib/execution/learner_thread.py +++ b/rllib/execution/learner_thread.py @@ -8,8 +8,8 @@ from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.metrics.learner_info import LearnerInfoBuilder, \ LEARNER_INFO, LEARNER_STATS_KEY +from ray.rllib.utils.metrics.window_stat import WindowStat from ray.rllib.utils.timer import TimerStat -from ray.rllib.utils.window_stat import WindowStat from ray.util.iter import _NextValueNotReady tf1, tf, tfv = try_import_tf() diff --git a/rllib/execution/replay_buffer.py b/rllib/execution/replay_buffer.py index 3a2e461cd38b..3c07272760d0 100644 --- a/rllib/execution/replay_buffer.py +++ b/rllib/execution/replay_buffer.py @@ -17,10 +17,10 @@ from ray.rllib.utils.annotations import DeveloperAPI, override from ray.util.iter import ParallelIteratorWorker from ray.util.debug import log_once -from ray.rllib.utils.annotations import Deprecated -from ray.rllib.utils.deprecation import DEPRECATED_VALUE, deprecation_warning +from ray.rllib.utils.deprecation import Deprecated, DEPRECATED_VALUE, \ + deprecation_warning from ray.rllib.utils.timer import TimerStat -from ray.rllib.utils.window_stat import WindowStat +from ray.rllib.utils.metrics.window_stat import WindowStat from ray.rllib.utils.typing import SampleBatchType # Constant that represents all policies in lockstep replay mode. diff --git a/rllib/models/catalog.py b/rllib/models/catalog.py index 180b08e2c665..048ca478c2eb 100644 --- a/rllib/models/catalog.py +++ b/rllib/models/catalog.py @@ -18,8 +18,8 @@ from ray.rllib.models.torch.torch_action_dist import TorchCategorical, \ TorchDeterministic, TorchDiagGaussian, \ TorchMultiActionDistribution, TorchMultiCategorical -from ray.rllib.utils.annotations import Deprecated, DeveloperAPI, PublicAPI -from ray.rllib.utils.deprecation import DEPRECATED_VALUE, \ +from ray.rllib.utils.annotations import DeveloperAPI, PublicAPI +from ray.rllib.utils.deprecation import Deprecated, DEPRECATED_VALUE, \ deprecation_warning from ray.rllib.utils.error import UnsupportedSpaceException from ray.rllib.utils.framework import try_import_tf, try_import_torch diff --git a/rllib/models/modelv2.py b/rllib/models/modelv2.py index db234dc4247e..971fc952c593 100644 --- a/rllib/models/modelv2.py +++ b/rllib/models/modelv2.py @@ -10,7 +10,8 @@ from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.view_requirement import ViewRequirement from ray.rllib.utils import NullContextManager -from ray.rllib.utils.annotations import Deprecated, DeveloperAPI, PublicAPI +from ray.rllib.utils.annotations import DeveloperAPI, PublicAPI +from ray.rllib.utils.deprecation import Deprecated from ray.rllib.utils.framework import try_import_tf, try_import_torch, \ TensorType from ray.rllib.utils.spaces.repeated import Repeated diff --git a/rllib/models/tf/attention_net.py b/rllib/models/tf/attention_net.py index 05e28e989eaf..5d0fe6aafb68 100644 --- a/rllib/models/tf/attention_net.py +++ b/rllib/models/tf/attention_net.py @@ -22,7 +22,7 @@ from ray.rllib.policy.view_requirement import ViewRequirement from ray.rllib.utils.annotations import override from ray.rllib.utils.framework import try_import_tf -from ray.rllib.utils.tf_ops import one_hot +from ray.rllib.utils.tf_utils import one_hot from ray.rllib.utils.typing import ModelConfigDict, TensorType, List tf1, tf, tfv = try_import_tf() diff --git a/rllib/models/tf/complex_input_net.py b/rllib/models/tf/complex_input_net.py index c7323c41cab9..8b4ffa801f4b 100644 --- a/rllib/models/tf/complex_input_net.py +++ b/rllib/models/tf/complex_input_net.py @@ -11,7 +11,7 @@ from ray.rllib.utils.annotations import override from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.spaces.space_utils import flatten_space -from ray.rllib.utils.tf_ops import one_hot +from ray.rllib.utils.tf_utils import one_hot tf1, tf, tfv = try_import_tf() diff --git a/rllib/models/tf/recurrent_net.py b/rllib/models/tf/recurrent_net.py index 862763304e63..aa68808bbbbf 100644 --- a/rllib/models/tf/recurrent_net.py +++ b/rllib/models/tf/recurrent_net.py @@ -11,7 +11,7 @@ from ray.rllib.policy.view_requirement import ViewRequirement from ray.rllib.utils.annotations import override, DeveloperAPI from ray.rllib.utils.framework import try_import_tf -from ray.rllib.utils.tf_ops import one_hot +from ray.rllib.utils.tf_utils import one_hot from ray.rllib.utils.typing import ModelConfigDict, TensorType tf1, tf, tfv = try_import_tf() diff --git a/rllib/models/torch/torch_action_dist.py b/rllib/models/torch/torch_action_dist.py index 43a75b281003..0487bcf19de2 100644 --- a/rllib/models/torch/torch_action_dist.py +++ b/rllib/models/torch/torch_action_dist.py @@ -11,7 +11,6 @@ from ray.rllib.utils.numpy import SMALL_NUMBER, MIN_LOG_NN_OUTPUT, \ MAX_LOG_NN_OUTPUT from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space -from ray.rllib.utils.torch_ops import atanh from ray.rllib.utils.typing import TensorType, List, Union, \ Tuple, ModelConfigDict @@ -300,7 +299,7 @@ def _unsquash(self, values: TensorType) -> TensorType: # Stabilize input to atanh. save_normed_values = torch.clamp(normed_values, -1.0 + SMALL_NUMBER, 1.0 - SMALL_NUMBER) - unsquashed = atanh(save_normed_values) + unsquashed = torch.atanh(save_normed_values) return unsquashed @staticmethod diff --git a/rllib/policy/dynamic_tf_policy.py b/rllib/policy/dynamic_tf_policy.py index acb272f3bfa1..3f5ac3044a2b 100644 --- a/rllib/policy/dynamic_tf_policy.py +++ b/rllib/policy/dynamic_tf_policy.py @@ -17,7 +17,7 @@ from ray.rllib.utils.debug import summarize from ray.rllib.utils.deprecation import deprecation_warning, DEPRECATED_VALUE from ray.rllib.utils.framework import try_import_tf -from ray.rllib.utils.tf_ops import get_placeholder +from ray.rllib.utils.tf_utils import get_placeholder from ray.rllib.utils.typing import LocalOptimizer, ModelGradients, \ TensorType, TrainerConfigDict diff --git a/rllib/policy/eager_tf_policy.py b/rllib/policy/eager_tf_policy.py index 75563663c357..7248bd9feb91 100644 --- a/rllib/policy/eager_tf_policy.py +++ b/rllib/policy/eager_tf_policy.py @@ -20,7 +20,7 @@ from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY from ray.rllib.utils.spaces.space_utils import normalize_action -from ray.rllib.utils.tf_ops import get_gpu_devices +from ray.rllib.utils.tf_utils import get_gpu_devices from ray.rllib.utils.threading import with_lock from ray.rllib.utils.typing import LocalOptimizer, TensorType @@ -724,7 +724,7 @@ def get_session(self): def get_placeholder(self, ph): raise ValueError( "get_placeholder() is not allowed in eager mode. Try using " - "rllib.utils.tf_ops.make_tf_callable() to write " + "rllib.utils.tf_utils.make_tf_callable() to write " "functions that work in both graph and eager mode.") def loss_initialized(self): diff --git a/rllib/policy/policy.py b/rllib/policy/policy.py index 7e67624b7fe5..1fd9b70e8d93 100644 --- a/rllib/policy/policy.py +++ b/rllib/policy/policy.py @@ -10,7 +10,8 @@ from ray.rllib.models.catalog import ModelCatalog from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.view_requirement import ViewRequirement -from ray.rllib.utils.annotations import Deprecated, DeveloperAPI +from ray.rllib.utils.annotations import DeveloperAPI +from ray.rllib.utils.deprecation import Deprecated from ray.rllib.utils.exploration.exploration import Exploration from ray.rllib.utils.framework import try_import_tf, try_import_torch from ray.rllib.utils.from_config import from_config diff --git a/rllib/policy/policy_map.py b/rllib/policy/policy_map.py index abdf591c6efe..ad7e940f59d3 100644 --- a/rllib/policy/policy_map.py +++ b/rllib/policy/policy_map.py @@ -8,7 +8,7 @@ from ray.rllib.policy.policy import PolicySpec from ray.rllib.utils.annotations import override from ray.rllib.utils.framework import try_import_tf -from ray.rllib.utils.tf_ops import get_tf_eager_cls_if_necessary +from ray.rllib.utils.tf_utils import get_tf_eager_cls_if_necessary from ray.rllib.utils.threading import with_lock from ray.rllib.utils.typing import PartialTrainerConfigDict, \ PolicyID, TrainerConfigDict diff --git a/rllib/policy/sample_batch.py b/rllib/policy/sample_batch.py index 389278a1a432..4dbe3436cc0c 100644 --- a/rllib/policy/sample_batch.py +++ b/rllib/policy/sample_batch.py @@ -6,12 +6,12 @@ from typing import Dict, Iterator, List, Optional, Set, Union from ray.util import log_once -from ray.rllib.utils.annotations import Deprecated, DeveloperAPI, \ +from ray.rllib.utils.annotations import DeveloperAPI, \ PublicAPI from ray.rllib.utils.compression import pack, unpack, is_compressed -from ray.rllib.utils.deprecation import deprecation_warning +from ray.rllib.utils.deprecation import Deprecated, deprecation_warning from ray.rllib.utils.framework import try_import_tf, try_import_torch -from ray.rllib.utils.memory import concat_aligned +from ray.rllib.utils.numpy import concat_aligned from ray.rllib.utils.typing import PolicyID, TensorType tf1, tf, tfv = try_import_tf() diff --git a/rllib/policy/tf_policy.py b/rllib/policy/tf_policy.py index c4f67f8e17ef..7f3b3fd81b29 100644 --- a/rllib/policy/tf_policy.py +++ b/rllib/policy/tf_policy.py @@ -15,14 +15,14 @@ from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.utils import force_list -from ray.rllib.utils.annotations import Deprecated, DeveloperAPI, override +from ray.rllib.utils.annotations import DeveloperAPI, override from ray.rllib.utils.debug import summarize -from ray.rllib.utils.deprecation import deprecation_warning +from ray.rllib.utils.deprecation import Deprecated, deprecation_warning from ray.rllib.utils.framework import try_import_tf, get_variable from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY from ray.rllib.utils.schedules import PiecewiseSchedule from ray.rllib.utils.spaces.space_utils import normalize_action -from ray.rllib.utils.tf_ops import get_gpu_devices +from ray.rllib.utils.tf_utils import get_gpu_devices from ray.rllib.utils.tf_run_builder import TFRunBuilder from ray.rllib.utils.typing import LocalOptimizer, ModelGradients, \ TensorType, TrainerConfigDict diff --git a/rllib/policy/torch_policy.py b/rllib/policy/torch_policy.py index 54df8d487fbc..7ea18771b44d 100644 --- a/rllib/policy/torch_policy.py +++ b/rllib/policy/torch_policy.py @@ -22,11 +22,11 @@ from ray.rllib.utils.annotations import override, DeveloperAPI from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY +from ray.rllib.utils.numpy import convert_to_numpy from ray.rllib.utils.schedules import PiecewiseSchedule from ray.rllib.utils.spaces.space_utils import normalize_action from ray.rllib.utils.threading import with_lock -from ray.rllib.utils.torch_ops import convert_to_non_torch_type, \ - convert_to_torch_tensor +from ray.rllib.utils.torch_ops import convert_to_torch_tensor from ray.rllib.utils.typing import ModelGradients, ModelWeights, TensorType, \ TensorStructType, TrainerConfigDict @@ -671,7 +671,7 @@ def get_state(self) -> Union[Dict[str, TensorType], List[TensorType]]: state = super().get_state() state["_optimizer_variables"] = [] for i, o in enumerate(self._optimizers): - optim_state_dict = convert_to_non_torch_type(o.state_dict()) + optim_state_dict = convert_to_numpy(o.state_dict()) state["_optimizer_variables"].append(optim_state_dict) # Add exploration state. state["_exploration_state"] = \ @@ -940,7 +940,7 @@ def _compute_action_helper(self, input_dict, state_batches, seq_lens, # Update our global timestep by the batch size. self.global_timestep += len(input_dict[SampleBatch.CUR_OBS]) - return convert_to_non_torch_type((actions, state_out, extra_fetches)) + return convert_to_numpy((actions, state_out, extra_fetches)) def _lazy_tensor_dict(self, postprocessed_batch: SampleBatch, device=None): # TODO: (sven): Keep for a while to ensure backward compatibility. diff --git a/rllib/policy/torch_policy_template.py b/rllib/policy/torch_policy_template.py index ee7a2d8abc8e..2a72e1224e96 100644 --- a/rllib/policy/torch_policy_template.py +++ b/rllib/policy/torch_policy_template.py @@ -7,7 +7,7 @@ from ray.rllib.policy.policy_template import build_policy_class from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.torch_policy import TorchPolicy -from ray.rllib.utils.annotations import Deprecated +from ray.rllib.utils.deprecation import Deprecated from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.typing import ModelGradients, TensorType, \ TrainerConfigDict diff --git a/rllib/utils/annotations.py b/rllib/utils/annotations.py index de815b5ba311..2df9ef2eb7d2 100644 --- a/rllib/utils/annotations.py +++ b/rllib/utils/annotations.py @@ -1,7 +1,4 @@ -import inspect - -from ray.util import log_once -from ray.rllib.utils.deprecation import deprecation_warning +from ray.rllib.utils.deprecation import Deprecated def override(cls): @@ -53,55 +50,20 @@ def DeveloperAPI(obj): return obj -def Deprecated(old=None, *, new=None, help=None, error): - """Annotation for documenting a (soon-to-be) deprecated method. +def ExperimentalAPI(obj): + """Annotation for documenting experimental APIs. + + Experimental APIs are classes and methods that are in development and may + change at any time in their development process. You should not expect + these APIs to be stable until their tag is changed to `DeveloperAPI` or + `PublicAPI`. - Methods tagged with this decorator should produce a - `ray.rllib.utils.deprecation.deprecation_warning(old=..., error=False)` - to not break existing code at this point. - In a next major release, this warning can then be made an error - (error=True), which means at this point that the method is already - no longer supported but will still inform the user about the - deprecation event. - In a further major release, the method should be erased. + Subclasses that inherit from a ``@ExperimentalAPI`` base class can be + assumed experimental as well. """ - def _inner(obj): - # A deprecated class. - if inspect.isclass(obj): - # Patch the class' init method to raise the warning/error. - obj_init = obj.__init__ - - def patched_init(*args, **kwargs): - if log_once(old or obj.__name__): - deprecation_warning( - old=old or obj.__name__, - new=new, - help=help, - error=error, - ) - return obj_init(*args, **kwargs) - - obj.__init__ = patched_init - # Return the patched class (with the warning/error when - # instantiated). - return obj - - # A deprecated class method or function. - # Patch with the warning/error at the beginning. - def _ctor(*args, **kwargs): - if log_once(old or obj.__name__): - deprecation_warning( - old=old or obj.__name__, - new=new, - help=help, - error=error, - ) - # Call the deprecated method/function. - return obj(*args, **kwargs) - - # Return the patched class method/function. - return _ctor - - # Return the prepared decorator. - return _inner + return obj + + +# Backward compatibility. +Deprecated = Deprecated diff --git a/rllib/utils/debug.py b/rllib/utils/debug.py index e6f769f0f04e..90d475cdf9e1 100644 --- a/rllib/utils/debug.py +++ b/rllib/utils/debug.py @@ -2,7 +2,7 @@ import os import pprint import random -from typing import Mapping, Optional +from typing import Any, Mapping, Optional from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch from ray.rllib.utils.framework import try_import_tf, try_import_torch @@ -10,11 +10,17 @@ _printer = pprint.PrettyPrinter(indent=2, width=60) -def summarize(obj): +def summarize(obj: Any) -> Any: """Return a pretty-formatted string for an object. This has special handling for pretty-formatting of commonly used data types in RLlib, such as SampleBatch, numpy arrays, etc. + + Args: + obj: The object to format. + + Returns: + The summarized object. """ return _printer.pformat(_summarize(obj)) @@ -76,8 +82,8 @@ def update_global_seed_if_necessary(framework: Optional[str] = None, This is useful for debugging and testing. Args: - framework (Optional[str]): The framework specifier (may be None). - seed (Optional[int]): An optional int seed. If None, will not do + framework: The framework specifier (may be None). + seed: An optional int seed. If None, will not do anything. """ if seed is None: diff --git a/rllib/utils/deprecation.py b/rllib/utils/deprecation.py index 8cda88eaf9d8..ec4559f74aa2 100644 --- a/rllib/utils/deprecation.py +++ b/rllib/utils/deprecation.py @@ -1,6 +1,9 @@ +import inspect import logging from typing import Optional, Union +from ray.util import log_once + logger = logging.getLogger(__name__) # A constant to use for any configuration that should be deprecated @@ -23,8 +26,12 @@ def deprecation_warning( help (Optional[str]): An optional help text to tell the user, what to do instead of using `old`. error (Optional[Union[bool, Exception]]): Whether or which exception to - throw. If True, throw ValueError. If False, just warn. - If Exception, throw that Exception. + raise. If True, raise ValueError. If False, just warn. + If error is-a subclass of Exception, raise that Exception. + + Raises: + ValueError: If `error=True`. + Exception: Of type `error`, iff error is-a Exception subclass. """ msg = "`{}` has been deprecated.{}".format( old, (" Use `{}` instead.".format(new) if new else f" {help}" @@ -37,3 +44,57 @@ def deprecation_warning( else: logger.warning("DeprecationWarning: " + msg + " This will raise an error in the future!") + + +def Deprecated(old=None, *, new=None, help=None, error): + """Annotation for documenting a (soon-to-be) deprecated method. + + Methods tagged with this decorator should produce a + `ray.rllib.utils.deprecation.deprecation_warning(old=..., error=False)` + to not break existing code at this point. + In a next major release, this warning can then be made an error + (error=True), which means at this point that the method is already + no longer supported but will still inform the user about the + deprecation event. + In a further major release, the method should be erased. + """ + + def _inner(obj): + # A deprecated class. + if inspect.isclass(obj): + # Patch the class' init method to raise the warning/error. + obj_init = obj.__init__ + + def patched_init(*args, **kwargs): + if log_once(old or obj.__name__): + deprecation_warning( + old=old or obj.__name__, + new=new, + help=help, + error=error, + ) + return obj_init(*args, **kwargs) + + obj.__init__ = patched_init + # Return the patched class (with the warning/error when + # instantiated). + return obj + + # A deprecated class method or function. + # Patch with the warning/error at the beginning. + def _ctor(*args, **kwargs): + if log_once(old or obj.__name__): + deprecation_warning( + old=old or obj.__name__, + new=new, + help=help, + error=error, + ) + # Call the deprecated method/function. + return obj(*args, **kwargs) + + # Return the patched class method/function. + return _ctor + + # Return the prepared decorator. + return _inner diff --git a/rllib/utils/exploration/curiosity.py b/rllib/utils/exploration/curiosity.py index 9cad2d25443f..7b7c586aa294 100644 --- a/rllib/utils/exploration/curiosity.py +++ b/rllib/utils/exploration/curiosity.py @@ -17,7 +17,7 @@ from ray.rllib.utils.framework import try_import_tf, \ try_import_torch from ray.rllib.utils.from_config import from_config -from ray.rllib.utils.tf_ops import get_placeholder, one_hot as tf_one_hot +from ray.rllib.utils.tf_utils import get_placeholder, one_hot as tf_one_hot from ray.rllib.utils.torch_ops import one_hot from ray.rllib.utils.typing import FromConfigSpec, ModelConfigDict, TensorType diff --git a/rllib/utils/exploration/exploration.py b/rllib/utils/exploration/exploration.py index b6eb32b005db..8942e03a10a2 100644 --- a/rllib/utils/exploration/exploration.py +++ b/rllib/utils/exploration/exploration.py @@ -5,7 +5,8 @@ from ray.rllib.models.action_dist import ActionDistribution from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.policy.sample_batch import SampleBatch -from ray.rllib.utils.annotations import Deprecated, DeveloperAPI +from ray.rllib.utils.annotations import DeveloperAPI +from ray.rllib.utils.deprecation import Deprecated from ray.rllib.utils.framework import try_import_torch, TensorType from ray.rllib.utils.typing import LocalOptimizer, TrainerConfigDict diff --git a/rllib/utils/exploration/gaussian_noise.py b/rllib/utils/exploration/gaussian_noise.py index 3c1972d1e5fd..e49a2ba12960 100644 --- a/rllib/utils/exploration/gaussian_noise.py +++ b/rllib/utils/exploration/gaussian_noise.py @@ -12,7 +12,7 @@ from ray.rllib.utils.numpy import convert_to_numpy from ray.rllib.utils.schedules import Schedule from ray.rllib.utils.schedules.piecewise_schedule import PiecewiseSchedule -from ray.rllib.utils.tf_ops import zero_logps_from_actions +from ray.rllib.utils.tf_utils import zero_logps_from_actions tf1, tf, tfv = try_import_tf() torch, _ = try_import_torch() diff --git a/rllib/utils/exploration/ornstein_uhlenbeck_noise.py b/rllib/utils/exploration/ornstein_uhlenbeck_noise.py index ba7582903cf5..d8f9f8b39715 100644 --- a/rllib/utils/exploration/ornstein_uhlenbeck_noise.py +++ b/rllib/utils/exploration/ornstein_uhlenbeck_noise.py @@ -8,7 +8,7 @@ get_variable, TensorType from ray.rllib.utils.numpy import convert_to_numpy from ray.rllib.utils.schedules import Schedule -from ray.rllib.utils.tf_ops import zero_logps_from_actions +from ray.rllib.utils.tf_utils import zero_logps_from_actions tf1, tf, tfv = try_import_tf() torch, _ = try_import_torch() diff --git a/rllib/utils/exploration/random.py b/rllib/utils/exploration/random.py index d1d6c4d0ad98..a375e1e9deb6 100644 --- a/rllib/utils/exploration/random.py +++ b/rllib/utils/exploration/random.py @@ -12,7 +12,7 @@ TensorType from ray.rllib.utils.spaces.simplex import Simplex from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space -from ray.rllib.utils.tf_ops import zero_logps_from_actions +from ray.rllib.utils.tf_utils import zero_logps_from_actions tf1, tf, tfv = try_import_tf() torch, _ = try_import_torch() diff --git a/rllib/utils/exploration/stochastic_sampling.py b/rllib/utils/exploration/stochastic_sampling.py index 593233625de1..4c9f53645832 100644 --- a/rllib/utils/exploration/stochastic_sampling.py +++ b/rllib/utils/exploration/stochastic_sampling.py @@ -10,7 +10,7 @@ from ray.rllib.utils.exploration.random import Random from ray.rllib.utils.framework import get_variable, try_import_tf, \ try_import_torch, TensorType -from ray.rllib.utils.tf_ops import zero_logps_from_actions +from ray.rllib.utils.tf_utils import zero_logps_from_actions tf1, tf, tfv = try_import_tf() torch, _ = try_import_torch() diff --git a/rllib/utils/framework.py b/rllib/utils/framework.py index 8057434d67e4..0670f832882b 100644 --- a/rllib/utils/framework.py +++ b/rllib/utils/framework.py @@ -4,20 +4,20 @@ import sys from typing import Any, Optional -from ray.rllib.utils.annotations import Deprecated +from ray.rllib.utils.deprecation import Deprecated from ray.rllib.utils.typing import TensorShape, TensorType logger = logging.getLogger(__name__) -def try_import_jax(error=False): +def try_import_jax(error: bool = False): """Tries importing JAX and FLAX and returns both modules (or Nones). Args: - error (bool): Whether to raise an error if JAX/FLAX cannot be imported. + error: Whether to raise an error if JAX/FLAX cannot be imported. Returns: - Tuple: The jax- and the flax modules. + Tuple containing the jax- and the flax modules. Raises: ImportError: If error=True and JAX is not installed. @@ -39,18 +39,17 @@ def try_import_jax(error=False): return jax, flax -def try_import_tf(error=False): +def try_import_tf(error: bool = False): """Tries importing tf and returns the module (or None). Args: - error (bool): Whether to raise an error if tf cannot be imported. + error: Whether to raise an error if tf cannot be imported. Returns: - Tuple: - - tf1.x module (either from tf2.x.compat.v1 OR as tf1.x). - - tf module (resulting from `import tensorflow`). - Either tf1.x or 2.x. - - The actually installed tf version as int: 1 or 2. + Tuple containing + 1) tf1.x module (either from tf2.x.compat.v1 OR as tf1.x). + 2) tf module (resulting from `import tensorflow`). Either tf1.x or + 2.x. 3) The actually installed tf version as int: 1 or 2. Raises: ImportError: If error=True and tf is not installed. @@ -119,11 +118,11 @@ def decorator(func): return decorator -def try_import_tfp(error=False): +def try_import_tfp(error: bool = False): """Tries importing tfp and returns the module (or None). Args: - error (bool): Whether to raise an error if tfp cannot be imported. + error: Whether to raise an error if tfp cannot be imported. Returns: The tfp module. @@ -159,14 +158,14 @@ def __init__(self, *a, **kw): raise ImportError("Could not import `torch`.") -def try_import_torch(error=False): +def try_import_torch(error: bool = False): """Tries importing torch and returns the module (or None). Args: - error (bool): Whether to raise an error if torch cannot be imported. + error: Whether to raise an error if torch cannot be imported. Returns: - tuple: torch AND torch.nn modules. + Tuple consisting of the torch- AND torch.nn modules. Raises: ImportError: If error=True and PyTorch is not installed. @@ -201,7 +200,8 @@ def get_variable(value: Any, device: Optional[str] = None, shape: Optional[TensorShape] = None, dtype: Optional[TensorType] = None) -> Any: - """ + """Creates a tf variable, a torch tensor, or a python primitive. + Args: value: The initial value to use. In the non-tf case, this will be returned as is. In the tf case, this could be a tf-Initializer @@ -223,7 +223,7 @@ def get_variable(value: Any, Returns: A framework-specific variable (tf.Variable, torch.tensor, or - python primitive). + python primitive). """ if framework in ["tf2", "tf", "tfe"]: import tensorflow as tf @@ -258,8 +258,8 @@ def get_variable(value: Any, @Deprecated( - old="rllib/models/utils.py::get_activation_fn", - new="rllib/utils/framework.py::get_activation_fn", + old="rllib/utils/framework.py::get_activation_fn", + new="rllib/models/utils.py::get_activation_fn", error=False) def get_activation_fn(name: Optional[str] = None, framework: str = "tf"): """Returns a framework specific activation function, given a name string. diff --git a/rllib/utils/memory.py b/rllib/utils/memory.py index c2989a407ce8..48248c602e5b 100644 --- a/rllib/utils/memory.py +++ b/rllib/utils/memory.py @@ -1,73 +1,8 @@ -import numpy as np - - -def aligned_array(size, dtype, align=64): - """Returns an array of a given size that is 64-byte aligned. - - The returned array can be efficiently copied into GPU memory by TensorFlow. - """ - - n = size * dtype.itemsize - empty = np.empty(n + (align - 1), dtype=np.uint8) - data_align = empty.ctypes.data % align - offset = 0 if data_align == 0 else (align - data_align) - if n == 0: - # stop np from optimising out empty slice reference - output = empty[offset:offset + 1][0:0].view(dtype) - else: - output = empty[offset:offset + n].view(dtype) - - assert len(output) == size, len(output) - assert output.ctypes.data % align == 0, output.ctypes.data - return output - - -def concat_aligned(items, time_major=None): - """Concatenate arrays, ensuring the output is 64-byte aligned. - - We only align float arrays; other arrays are concatenated as normal. - - This should be used instead of np.concatenate() to improve performance - when the output array is likely to be fed into TensorFlow. - - Args: - items (List(np.ndarray)): The list of items to concatenate and align. - time_major (bool): Whether the data in items is time-major, in which - case, we will concatenate along axis=1. - - Returns: - np.ndarray: The concat'd and aligned array. - """ - - if len(items) == 0: - return [] - elif len(items) == 1: - # we assume the input is aligned. In any case, it doesn't help - # performance to force align it since that incurs a needless copy. - return items[0] - elif (isinstance(items[0], np.ndarray) - and items[0].dtype in [np.float32, np.float64, np.uint8]): - dtype = items[0].dtype - flat = aligned_array(sum(s.size for s in items), dtype) - if time_major is not None: - if time_major is True: - batch_dim = sum(s.shape[1] for s in items) - new_shape = ( - items[0].shape[0], - batch_dim, - ) + items[0].shape[2:] - else: - batch_dim = sum(s.shape[0] for s in items) - new_shape = ( - batch_dim, - items[0].shape[1], - ) + items[0].shape[2:] - else: - batch_dim = sum(s.shape[0] for s in items) - new_shape = (batch_dim, ) + items[0].shape[1:] - output = flat.reshape(new_shape) - assert output.ctypes.data % 64 == 0, output.ctypes.data - np.concatenate(items, out=output, axis=1 if time_major else 0) - return output - else: - return np.concatenate(items, axis=1 if time_major else 0) +from ray.rllib.utils.deprecation import deprecation_warning +from ray.rllib.utils.numpy import aligned_array, concat_aligned # noqa + +deprecation_warning( + old="ray.rllib.utils.memory.[...]", + new="ray.rllib.utils.numpy.[...]", + error=False, +) diff --git a/rllib/utils/metrics/window_stat.py b/rllib/utils/metrics/window_stat.py new file mode 100644 index 000000000000..9aa0d9f301df --- /dev/null +++ b/rllib/utils/metrics/window_stat.py @@ -0,0 +1,28 @@ +import numpy as np + + +class WindowStat: + def __init__(self, name, n): + self.name = name + self.items = [None] * n + self.idx = 0 + self.count = 0 + + def push(self, obj): + self.items[self.idx] = obj + self.idx += 1 + self.count += 1 + self.idx %= len(self.items) + + def stats(self): + if not self.count: + _quantiles = [] + else: + _quantiles = np.nanpercentile(self.items[:self.count], + [0, 10, 50, 90, 100]).tolist() + return { + self.name + "_count": int(self.count), + self.name + "_mean": float(np.nanmean(self.items[:self.count])), + self.name + "_std": float(np.nanstd(self.items[:self.count])), + self.name + "_quantiles": _quantiles, + } diff --git a/rllib/utils/multi_agent.py b/rllib/utils/multi_agent.py index 50d5227c54e7..82bf6b2089a1 100644 --- a/rllib/utils/multi_agent.py +++ b/rllib/utils/multi_agent.py @@ -11,12 +11,11 @@ def check_multi_agent(config: PartialTrainerConfigDict) -> \ """Checks, whether a (partial) config defines a multi-agent setup. Args: - config (PartialTrainerConfigDict): The user/Trainer/Policy config - to check for multi-agent. + config: The user/Trainer/Policy config to check for multi-agent. Returns: - The resulting (all fixed) multi-agent policy dict and whether we - have a multi-agent setup or not. + Tuple consisting of the resulting (all fixed) multi-agent policy + dict and bool indicating whether we have a multi-agent setup or not. """ multiagent_config = config["multiagent"] policies = multiagent_config.get("policies") diff --git a/rllib/utils/numpy.py b/rllib/utils/numpy.py index 2a77db1a61fc..91f56f259125 100644 --- a/rllib/utils/numpy.py +++ b/rllib/utils/numpy.py @@ -1,8 +1,9 @@ import numpy as np import tree # pip install dm_tree +from typing import List, Optional from ray.rllib.utils.framework import try_import_tf, try_import_torch -from ray.rllib.utils.typing import TensorType, Union +from ray.rllib.utils.typing import TensorType, TensorStructType, Union tf1, tf, tfv = try_import_tf() torch, _ = try_import_torch() @@ -17,146 +18,132 @@ MAX_LOG_NN_OUTPUT = 2 -def huber_loss(x, delta=1.0): - """Reference: https://en.wikipedia.org/wiki/Huber_loss""" - return np.where( - np.abs(x) < delta, - np.power(x, 2.0) * 0.5, delta * (np.abs(x) - 0.5 * delta)) - - -def l2_loss(x): - """Computes half the L2 norm of a tensor (w/o the sqrt): sum(x**2) / 2 - - Args: - x (np.ndarray): The input tensor. +def aligned_array(size: int, dtype, align: int = 64) -> np.ndarray: + """Returns an array of a given size that is 64-byte aligned. - Returns: - The l2-loss output according to the above formula given `x`. - """ - return np.sum(np.square(x)) / 2.0 - - -def sigmoid(x, derivative=False): - """ - Returns the sigmoid function applied to x. - Alternatively, can return the derivative or the sigmoid function. + The returned array can be efficiently copied into GPU memory by TensorFlow. Args: - x (np.ndarray): The input to the sigmoid function. - derivative (bool): Whether to return the derivative or not. - Default: False. + size: The size (total number of items) of the array. For example, + array([[0.0, 1.0], [2.0, 3.0]]) would have size=4. + dtype: The numpy dtype of the array. + align: The alignment to use. Returns: - np.ndarray: The sigmoid function (or its derivative) applied to x. + A np.ndarray with the given specifications. """ - if derivative: - return x * (1 - x) + n = size * dtype.itemsize + empty = np.empty(n + (align - 1), dtype=np.uint8) + data_align = empty.ctypes.data % align + offset = 0 if data_align == 0 else (align - data_align) + if n == 0: + # stop np from optimising out empty slice reference + output = empty[offset:offset + 1][0:0].view(dtype) else: - return 1 / (1 + np.exp(-x)) + output = empty[offset:offset + n].view(dtype) + assert len(output) == size, len(output) + assert output.ctypes.data % align == 0, output.ctypes.data + return output -def softmax(x, axis=-1): - """ - Returns the softmax values for x as: - S(xi) = e^xi / SUMj(e^xj), where j goes over all elements in x. - Args: - x (np.ndarray): The input to the softmax function. - axis (int): The axis along which to softmax. +def concat_aligned(items: List[np.ndarray], + time_major: Optional[bool] = None) -> np.ndarray: + """Concatenate arrays, ensuring the output is 64-byte aligned. - Returns: - np.ndarray: The softmax over x. - """ - # x_exp = np.maximum(np.exp(x), SMALL_NUMBER) - x_exp = np.exp(x) - # return x_exp / - # np.maximum(np.sum(x_exp, axis, keepdims=True), SMALL_NUMBER) - return np.maximum(x_exp / np.sum(x_exp, axis, keepdims=True), SMALL_NUMBER) + We only align float arrays; other arrays are concatenated as normal. - -def relu(x, alpha=0.0): - """ - Implementation of the leaky ReLU function: - y = x * alpha if x < 0 else x + This should be used instead of np.concatenate() to improve performance + when the output array is likely to be fed into TensorFlow. Args: - x (np.ndarray): The input values. - alpha (float): A scaling ("leak") factor to use for negative x. + items: The list of items to concatenate and align. + time_major: Whether the data in items is time-major, in which + case, we will concatenate along axis=1. Returns: - np.ndarray: The leaky ReLU output for x. + The concat'd and aligned array. """ - return np.maximum(x, x * alpha, x) + if len(items) == 0: + return [] + elif len(items) == 1: + # we assume the input is aligned. In any case, it doesn't help + # performance to force align it since that incurs a needless copy. + return items[0] + elif (isinstance(items[0], np.ndarray) + and items[0].dtype in [np.float32, np.float64, np.uint8]): + dtype = items[0].dtype + flat = aligned_array(sum(s.size for s in items), dtype) + if time_major is not None: + if time_major is True: + batch_dim = sum(s.shape[1] for s in items) + new_shape = ( + items[0].shape[0], + batch_dim, + ) + items[0].shape[2:] + else: + batch_dim = sum(s.shape[0] for s in items) + new_shape = ( + batch_dim, + items[0].shape[1], + ) + items[0].shape[2:] + else: + batch_dim = sum(s.shape[0] for s in items) + new_shape = (batch_dim, ) + items[0].shape[1:] + output = flat.reshape(new_shape) + assert output.ctypes.data % 64 == 0, output.ctypes.data + np.concatenate(items, out=output, axis=1 if time_major else 0) + return output + else: + return np.concatenate(items, axis=1 if time_major else 0) -def one_hot(x: Union[TensorType, int], - depth: int = 0, - on_value: int = 1.0, - off_value: float = 0.0): - """ - One-hot utility function for numpy. - Thanks to qianyizhang: - https://gist.github.com/qianyizhang/07ee1c15cad08afb03f5de69349efc30. + +def convert_to_numpy(x: TensorStructType, reduce_floats: bool = False): + """Converts values in `stats` to non-Tensor numpy or python types. Args: - x (TensorType): The input to be one-hot encoded. - depth (int): The max. number to be one-hot encoded (size of last rank). - on_value (float): The value to use for on. Default: 1.0. - off_value (float): The value to use for off. Default: 0.0. + x: Any (possibly nested) struct, the values in which will be + converted and returned as a new struct with all torch/tf tensors + being converted to numpy types. + reduce_floats: Whether to reduce all float64 data into float32 + automatically. Returns: - np.ndarray: The one-hot encoded equivalent of the input array. + A new struct with the same structure as `x`, but with all + values converted to numpy arrays (on CPU). """ - # Handle simple ints properly. - if isinstance(x, int): - x = np.array(x, dtype=np.int32) - # Handle torch arrays properly. - elif torch and isinstance(x, torch.Tensor): - x = x.numpy() - - # Handle bool arrays correctly. - if x.dtype == np.bool_: - x = x.astype(np.int) - depth = 2 - - # If depth is not given, try to infer it from the values in the array. - if depth == 0: - depth = np.max(x) + 1 - assert np.max(x) < depth, \ - "ERROR: The max. index of `x` ({}) is larger than depth ({})!".\ - format(np.max(x), depth) - shape = x.shape + # The mapping function used to numpyize torch/tf Tensors (and move them + # to the CPU beforehand). + def mapping(item): + if torch and isinstance(item, torch.Tensor): + ret = item.cpu().item() if len(item.size()) == 0 else \ + item.detach().cpu().numpy() + elif tf and isinstance(item, (tf.Tensor, tf.Variable)): + assert tf.executing_eagerly() + ret = item.numpy() + else: + ret = item + if reduce_floats and isinstance(ret, np.ndarray) and \ + ret.dtype == np.float64: + ret = ret.astype(np.float32) + return ret - # Python 2.7 compatibility, (*shape, depth) is not allowed. - shape_list = list(shape[:]) - shape_list.append(depth) - out = np.ones(shape_list) * off_value - indices = [] - for i in range(x.ndim): - tiles = [1] * x.ndim - s = [1] * x.ndim - s[i] = -1 - r = np.arange(shape[i]).reshape(s) - if i > 0: - tiles[i - 1] = shape[i - 1] - r = np.tile(r, tiles) - indices.append(r) - indices.append(x) - out[tuple(indices)] = on_value - return out + return tree.map_structure(mapping, x) -def fc(x, weights, biases=None, framework=None): - """ - Calculates the outputs of a fully-connected (dense) layer given - weights/biases and an input. +def fc(x: np.ndarray, + weights: np.ndarray, + biases: Optional[np.ndarray] = None, + framework: Optional[str] = None) -> np.ndarray: + """Calculates FC (dense) layer outputs given weights/biases and input. Args: - x (np.ndarray): The input to the dense layer. - weights (np.ndarray): The weights matrix. - biases (Optional[np.ndarray]): The biases vector. All 0s if None. - framework (Optional[str]): An optional framework hint (to figure out, + x: The input to the dense layer. + weights: The weights matrix. + biases: The biases vector. All 0s if None. + framework: An optional framework hint (to figure out, e.g. whether to transpose torch weight matrices). Returns: @@ -184,36 +171,48 @@ def map_(data, transpose=False): return np.matmul(x, weights) + (0.0 if biases is None else biases) -def lstm(x, - weights, - biases=None, - initial_internal_states=None, - time_major=False, - forget_bias=1.0): - """ - Calculates the outputs of an LSTM layer given weights/biases, - internal_states, and input. +def huber_loss(x: np.ndarray, delta: float = 1.0) -> np.ndarray: + """Reference: https://en.wikipedia.org/wiki/Huber_loss.""" + return np.where( + np.abs(x) < delta, + np.power(x, 2.0) * 0.5, delta * (np.abs(x) - 0.5 * delta)) + + +def l2_loss(x: np.ndarray) -> np.ndarray: + """Computes half the L2 norm of a tensor (w/o the sqrt): sum(x**2) / 2. Args: - x (np.ndarray): The inputs to the LSTM layer including time-rank - (0th if time-major, else 1st) and the batch-rank - (1st if time-major, else 0th). + x: The input tensor. - weights (np.ndarray): The weights matrix. - biases (Optional[np.ndarray]): The biases vector. All 0s if None. + Returns: + The l2-loss output according to the above formula given `x`. + """ + return np.sum(np.square(x)) / 2.0 - initial_internal_states (Optional[np.ndarray]): The initial internal - states to pass into the layer. All 0s if None. - time_major (bool): Whether to use time-major or not. Default: False. +def lstm(x, + weights: np.ndarray, + biases: Optional[np.ndarray] = None, + initial_internal_states: Optional[np.ndarray] = None, + time_major: bool = False, + forget_bias: float = 1.0): + """Calculates LSTM layer output given weights/biases, states, and input. - forget_bias (float): Gets added to first sigmoid (forget gate) output. + Args: + x: The inputs to the LSTM layer including time-rank + (0th if time-major, else 1st) and the batch-rank + (1st if time-major, else 0th). + weights: The weights matrix. + biases: The biases vector. All 0s if None. + initial_internal_states: The initial internal + states to pass into the layer. All 0s if None. + time_major: Whether to use time-major or not. Default: False. + forget_bias: Gets added to first sigmoid (forget gate) output. Default: 1.0. Returns: - Tuple: - - The LSTM layer's output. - - Tuple: Last (c-state, h-state). + Tuple consisting of 1) The LSTM layer's output and + 2) Tuple: Last (c-state, h-state). """ sequence_length = x.shape[0 if time_major else 1] batch_size = x.shape[1 if time_major else 0] @@ -259,36 +258,113 @@ def lstm(x, return unrolled_outputs, (c_states, h_states) -# TODO: (sven) this will replace `TorchPolicy._convert_to_non_torch_tensor()`. -def convert_to_numpy(x, reduce_floats=False): - """Converts values in `stats` to non-Tensor numpy or python types. +def one_hot(x: Union[TensorType, int], + depth: int = 0, + on_value: int = 1.0, + off_value: float = 0.0) -> np.ndarray: + """One-hot utility function for numpy. + + Thanks to qianyizhang: + https://gist.github.com/qianyizhang/07ee1c15cad08afb03f5de69349efc30. Args: - stats (any): Any (possibly nested) struct, the values in which will be - converted and returned as a new struct with all torch/tf tensors - being converted to numpy types. - reduce_floats (bool): Whether to reduce all float64 data into float32 - automatically. + x: The input to be one-hot encoded. + depth: The max. number to be one-hot encoded (size of last rank). + on_value: The value to use for on. Default: 1.0. + off_value: The value to use for off. Default: 0.0. Returns: - Any: A new struct with the same structure as `stats`, but with all - values converted to numpy arrays (on CPU). + The one-hot encoded equivalent of the input array. """ - # The mapping function used to numpyize torch/tf Tensors (and move them - # to the CPU beforehand). - def mapping(item): - if torch and isinstance(item, torch.Tensor): - ret = item.cpu().item() if len(item.size()) == 0 else \ - item.detach().cpu().numpy() - elif tf and isinstance(item, (tf.Tensor, tf.Variable)): - assert tf.executing_eagerly() - ret = item.numpy() - else: - ret = item - if reduce_floats and isinstance(ret, np.ndarray) and \ - ret.dtype == np.float64: - ret = ret.astype(np.float32) - return ret + # Handle simple ints properly. + if isinstance(x, int): + x = np.array(x, dtype=np.int32) + # Handle torch arrays properly. + elif torch and isinstance(x, torch.Tensor): + x = x.numpy() - return tree.map_structure(mapping, x) + # Handle bool arrays correctly. + if x.dtype == np.bool_: + x = x.astype(np.int) + depth = 2 + + # If depth is not given, try to infer it from the values in the array. + if depth == 0: + depth = np.max(x) + 1 + assert np.max(x) < depth, \ + "ERROR: The max. index of `x` ({}) is larger than depth ({})!".\ + format(np.max(x), depth) + shape = x.shape + + # Python 2.7 compatibility, (*shape, depth) is not allowed. + shape_list = list(shape[:]) + shape_list.append(depth) + out = np.ones(shape_list) * off_value + indices = [] + for i in range(x.ndim): + tiles = [1] * x.ndim + s = [1] * x.ndim + s[i] = -1 + r = np.arange(shape[i]).reshape(s) + if i > 0: + tiles[i - 1] = shape[i - 1] + r = np.tile(r, tiles) + indices.append(r) + indices.append(x) + out[tuple(indices)] = on_value + return out + + +def relu(x: np.ndarray, alpha: float = 0.0) -> np.ndarray: + """Implementation of the leaky ReLU function. + + y = x * alpha if x < 0 else x + + Args: + x: The input values. + alpha: A scaling ("leak") factor to use for negative x. + + Returns: + The leaky ReLU output for x. + """ + return np.maximum(x, x * alpha, x) + + +def sigmoid(x: np.ndarray, derivative: bool = False) -> np.ndarray: + """ + Returns the sigmoid function applied to x. + Alternatively, can return the derivative or the sigmoid function. + + Args: + x: The input to the sigmoid function. + derivative: Whether to return the derivative or not. + Default: False. + + Returns: + The sigmoid function (or its derivative) applied to x. + """ + if derivative: + return x * (1 - x) + else: + return 1 / (1 + np.exp(-x)) + + +def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray: + """Returns the softmax values for x. + + The exact formula used is: + S(xi) = e^xi / SUMj(e^xj), where j goes over all elements in x. + + Args: + x: The input to the softmax function. + axis: The axis along which to softmax. + + Returns: + The softmax over x. + """ + # x_exp = np.maximum(np.exp(x), SMALL_NUMBER) + x_exp = np.exp(x) + # return x_exp / + # np.maximum(np.sum(x_exp, axis, keepdims=True), SMALL_NUMBER) + return np.maximum(x_exp / np.sum(x_exp, axis, keepdims=True), SMALL_NUMBER) diff --git a/rllib/utils/test_utils.py b/rllib/utils/test_utils.py index a616fd9b4112..99baf070cd51 100644 --- a/rllib/utils/test_utils.py +++ b/rllib/utils/test_utils.py @@ -279,11 +279,11 @@ def check_compute_single_action(trainer, """Tests different combinations of args for trainer.compute_single_action. Args: - trainer (Trainer): The Trainer object to test. - include_state (bool): Whether to include the initial state of the - Policy's Model in the `compute_single_action` call. - include_prev_action_reward (bool): Whether to include the prev-action - and -reward in the `compute_single_action` call. + trainer: The Trainer object to test. + include_state: Whether to include the initial state of the Policy's + Model in the `compute_single_action` call. + include_prev_action_reward: Whether to include the prev-action and + -reward in the `compute_single_action` call. Raises: ValueError: If anything unexpected happens. diff --git a/rllib/utils/tf_ops.py b/rllib/utils/tf_ops.py index 1b577be7ef72..bdef1e5a4710 100644 --- a/rllib/utils/tf_ops.py +++ b/rllib/utils/tf_ops.py @@ -1,305 +1,8 @@ -import gym -from gym.spaces import Discrete, MultiDiscrete -import numpy as np -import tree # pip install dm_tree - -from ray.rllib.utils.framework import try_import_tf -from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space -from ray.rllib.utils.typing import TensorStructType, TensorType - -tf1, tf, tfv = try_import_tf() - - -def convert_to_non_tf_type(stats): - """Converts values in `stats` to non-Tensor numpy or python types. - - Args: - stats (any): Any (possibly nested) struct, the values in which will be - converted and returned as a new struct with all tf (eager) tensors - being converted to numpy types. - - Returns: - Any: A new struct with the same structure as `stats`, but with all - values converted to non-tf Tensor types. - """ - - # The mapping function used to numpyize torch Tensors. - def mapping(item): - if isinstance(item, (tf.Tensor, tf.Variable)): - return item.numpy() - else: - return item - - return tree.map_structure(mapping, stats) - - -def explained_variance(y, pred): - _, y_var = tf.nn.moments(y, axes=[0]) - _, diff_var = tf.nn.moments(y - pred, axes=[0]) - return tf.maximum(-1.0, 1 - (diff_var / y_var)) - - -def get_gpu_devices(): - """Returns a list of GPU device names, e.g. ["/gpu:0", "/gpu:1"]. - - Supports both tf1.x and tf2.x. - """ - if tfv == 1: - from tensorflow.python.client import device_lib - devices = device_lib.list_local_devices() - else: - try: - devices = tf.config.list_physical_devices() - except Exception: - devices = tf.config.experimental.list_physical_devices() - - # Expect "GPU", but also stuff like: "XLA_GPU". - return [d.name for d in devices if "GPU" in d.device_type] - - -def get_placeholder(*, - space=None, - value=None, - name=None, - time_axis=False, - flatten=True): - from ray.rllib.models.catalog import ModelCatalog - - if space is not None: - if isinstance(space, (gym.spaces.Dict, gym.spaces.Tuple)): - if flatten: - return ModelCatalog.get_action_placeholder(space, None) - else: - return tree.map_structure_with_path( - lambda path, component: get_placeholder( - space=component, - name=name + "." + ".".join([str(p) for p in path]), - ), - get_base_struct_from_space(space), - ) - return tf1.placeholder( - shape=(None, ) + ((None, ) if time_axis else ()) + space.shape, - dtype=tf.float32 if space.dtype == np.float64 else space.dtype, - name=name, - ) - else: - assert value is not None - shape = value.shape[1:] - return tf1.placeholder( - shape=(None, ) + ((None, ) - if time_axis else ()) + (shape if isinstance( - shape, tuple) else tuple(shape.as_list())), - dtype=tf.float32 if value.dtype == np.float64 else value.dtype, - name=name, - ) - - -def get_tf_eager_cls_if_necessary(orig_cls, config): - cls = orig_cls - framework = config.get("framework", "tf") - if framework in ["tf2", "tf", "tfe"]: - if not tf1: - raise ImportError("Could not import tensorflow!") - if framework in ["tf2", "tfe"]: - assert tf1.executing_eagerly() - - from ray.rllib.policy.tf_policy import TFPolicy - - # Create eager-class. - if hasattr(orig_cls, "as_eager"): - cls = orig_cls.as_eager() - if config.get("eager_tracing"): - cls = cls.with_tracing() - # Could be some other type of policy. - elif not issubclass(orig_cls, TFPolicy): - pass - else: - raise ValueError("This policy does not support eager " - "execution: {}".format(orig_cls)) - return cls - - -def huber_loss(x, delta=1.0): - """Reference: https://en.wikipedia.org/wiki/Huber_loss""" - return tf.where( - tf.abs(x) < delta, - tf.math.square(x) * 0.5, delta * (tf.abs(x) - 0.5 * delta)) - - -def zero_logps_from_actions(actions: TensorStructType) -> TensorType: - """Helper function useful for returning dummy logp's (0) for some actions. - - Args: - actions (TensorStructType): The input actions. This can be any struct - of complex action components or a simple tensor of different - dimensions, e.g. [B], [B, 2], or {"a": [B, 4, 5], "b": [B]}. - - Returns: - TensorType: A 1D tensor of 0.0 (dummy logp's) matching the batch - dim of `actions` (shape=[B]). - """ - # Need to flatten `actions` in case we have a complex action space. - # Take the 0th component to extract the batch dim. - action_component = tree.flatten(actions)[0] - logp_ = tf.zeros_like(action_component, dtype=tf.float32) - # Logp's should be single values (but with the same batch dim as - # `deterministic_actions` or `stochastic_actions`). In case - # actions are just [B], zeros_like works just fine here, but if - # actions are [B, ...], we have to reduce logp back to just [B]. - while len(logp_.shape) > 1: - logp_ = logp_[:, 0] - return logp_ - - -def one_hot(x, space): - if isinstance(space, Discrete): - return tf.one_hot(x, space.n, dtype=tf.float32) - elif isinstance(space, MultiDiscrete): - return tf.concat( - [ - tf.one_hot(x[:, i], n, dtype=tf.float32) - for i, n in enumerate(space.nvec) - ], - axis=-1) - else: - raise ValueError("Unsupported space for `one_hot`: {}".format(space)) - - -def reduce_mean_ignore_inf(x, axis): - """Same as tf.reduce_mean() but ignores -inf values.""" - mask = tf.not_equal(x, tf.float32.min) - x_zeroed = tf.where(mask, x, tf.zeros_like(x)) - return (tf.reduce_sum(x_zeroed, axis) / tf.reduce_sum( - tf.cast(mask, tf.float32), axis)) - - -def minimize_and_clip(optimizer, objective, var_list, clip_val=10.0): - """Minimized `objective` using `optimizer` w.r.t. variables in - `var_list` while ensure the norm of the gradients for each - variable is clipped to `clip_val` - """ - # Accidentally passing values < 0.0 will break all gradients. - assert clip_val is None or clip_val > 0.0, clip_val - - if tf.executing_eagerly(): - tape = optimizer.tape - grads_and_vars = list( - zip(list(tape.gradient(objective, var_list)), var_list)) - else: - grads_and_vars = optimizer.compute_gradients( - objective, var_list=var_list) - - return [(tf.clip_by_norm(g, clip_val) if clip_val is not None else g, v) - for (g, v) in grads_and_vars if g is not None] - - -def make_tf_callable(session_or_none, dynamic_shape=False): - """Returns a function that can be executed in either graph or eager mode. - - The function must take only positional args. - - If eager is enabled, this will act as just a function. Otherwise, it - will build a function that executes a session run with placeholders - internally. - - Args: - session_or_none (tf.Session): tf.Session if in graph mode, else None. - dynamic_shape (bool): True if the placeholders should have a dynamic - batch dimension. Otherwise they will be fixed shape. - - Returns: - a Python function that can be called in either mode. - """ - - if tf.executing_eagerly(): - assert session_or_none is None - else: - assert session_or_none is not None - - def make_wrapper(fn): - # Static-graph mode: Create placeholders and make a session call each - # time the wrapped function is called. Returns the output of this - # session call. - if session_or_none is not None: - args_placeholders = [] - kwargs_placeholders = {} - - symbolic_out = [None] - - def call(*args, **kwargs): - args_flat = [] - for a in args: - if type(a) is list: - args_flat.extend(a) - else: - args_flat.append(a) - args = args_flat - - # We have not built any placeholders yet: Do this once here, - # then reuse the same placeholders each time we call this - # function again. - if symbolic_out[0] is None: - with session_or_none.graph.as_default(): - - def _create_placeholders(path, value): - if dynamic_shape: - if len(value.shape) > 0: - shape = (None, ) + value.shape[1:] - else: - shape = () - else: - shape = value.shape - return tf1.placeholder( - dtype=value.dtype, - shape=shape, - name=".".join([str(p) for p in path]), - ) - - placeholders = tree.map_structure_with_path( - _create_placeholders, args) - for ph in tree.flatten(placeholders): - args_placeholders.append(ph) - - placeholders = tree.map_structure_with_path( - _create_placeholders, kwargs) - for k, ph in placeholders.items(): - kwargs_placeholders[k] = ph - - symbolic_out[0] = fn(*args_placeholders, - **kwargs_placeholders) - feed_dict = dict(zip(args_placeholders, tree.flatten(args))) - tree.map_structure(lambda ph, v: feed_dict.__setitem__(ph, v), - kwargs_placeholders, kwargs) - ret = session_or_none.run(symbolic_out[0], feed_dict) - return ret - - return call - # Eager mode (call function as is). - else: - return fn - - return make_wrapper - - -def scope_vars(scope, trainable_only=False): - """ - Get variables inside a scope - The scope can be specified as a string - - Parameters - ---------- - scope: str or VariableScope - scope in which the variables reside. - trainable_only: bool - whether or not to return only the variables that were marked as - trainable. - - Returns - ------- - vars: [tf.Variable] - list of variables in `scope`. - """ - return tf1.get_collection( - tf1.GraphKeys.TRAINABLE_VARIABLES - if trainable_only else tf1.GraphKeys.VARIABLES, - scope=scope if isinstance(scope, str) else scope.name) +from ray.rllib.utils.deprecation import deprecation_warning +from ray.rllib.utils.tf_utils import * # noqa + +deprecation_warning( + old="ray.rllib.utils.tf_ops.[...]", + new="ray.rllib.utils.tf_utils.[...]", + error=False, +) diff --git a/rllib/utils/tf_utils.py b/rllib/utils/tf_utils.py new file mode 100644 index 000000000000..7af39988b910 --- /dev/null +++ b/rllib/utils/tf_utils.py @@ -0,0 +1,426 @@ +import gym +from gym.spaces import Discrete, MultiDiscrete +import numpy as np +import tree # pip install dm_tree +from typing import Any, Callable, List, Optional, Type, TYPE_CHECKING, Union + +from ray.rllib.utils.deprecation import Deprecated +from ray.rllib.utils.framework import try_import_tf +from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space +from ray.rllib.utils.typing import LocalOptimizer, ModelGradients, \ + PartialTrainerConfigDict, TensorStructType, TensorType + +if TYPE_CHECKING: + from ray.rllib.policy.tf_policy import TFPolicy + +tf1, tf, tfv = try_import_tf() + + +@Deprecated(new="ray.rllib.utils.numpy.convert_to_numpy()", error=True) +def convert_to_non_tf_type(x: TensorStructType) -> TensorStructType: + """Converts values in `stats` to non-Tensor numpy or python types. + + Args: + x: Any (possibly nested) struct, the values in which will be + converted and returned as a new struct with all tf (eager) tensors + being converted to numpy types. + + Returns: + A new struct with the same structure as `x`, but with all + values converted to non-tf Tensor types. + """ + + # The mapping function used to numpyize torch Tensors. + def mapping(item): + if isinstance(item, (tf.Tensor, tf.Variable)): + return item.numpy() + else: + return item + + return tree.map_structure(mapping, x) + + +def explained_variance(y: TensorType, pred: TensorType) -> TensorType: + """Computes the explained variance for a pair of labels and predictions. + + The formula used is: + max(-1.0, 1.0 - (std(y - pred)^2 / std(y)^2)) + + Args: + y: The labels. + pred: The predictions. + + Returns: + The explained variance given a pair of labels and predictions. + """ + _, y_var = tf.nn.moments(y, axes=[0]) + _, diff_var = tf.nn.moments(y - pred, axes=[0]) + return tf.maximum(-1.0, 1 - (diff_var / y_var)) + + +def get_gpu_devices() -> List[str]: + """Returns a list of GPU device names, e.g. ["/gpu:0", "/gpu:1"]. + + Supports both tf1.x and tf2.x. + + Returns: + List of GPU device names (str). + """ + if tfv == 1: + from tensorflow.python.client import device_lib + devices = device_lib.list_local_devices() + else: + try: + devices = tf.config.list_physical_devices() + except Exception: + devices = tf.config.experimental.list_physical_devices() + + # Expect "GPU", but also stuff like: "XLA_GPU". + return [d.name for d in devices if "GPU" in d.device_type] + + +def get_placeholder(*, + space: Optional[gym.Space] = None, + value: Optional[Any] = None, + name: Optional[str] = None, + time_axis: bool = False, + flatten: bool = True) -> "tf1.placeholder": + """Returns a tf1.placeholder object given optional hints, such as a space. + + Note that the returned placeholder will always have a leading batch + dimension (None). + + Args: + space: An optional gym.Space to hint the shape and dtype of the + placeholder. + value: An optional value to hint the shape and dtype of the + placeholder. + name: An optional name for the placeholder. + time_axis: Whether the placeholder should also receive a time + dimension (None). + flatten: Whether to flatten the given space into a plain Box space + and then create the placeholder from the resulting space. + + Returns: + The tf1 placeholder. + """ + from ray.rllib.models.catalog import ModelCatalog + + if space is not None: + if isinstance(space, (gym.spaces.Dict, gym.spaces.Tuple)): + if flatten: + return ModelCatalog.get_action_placeholder(space, None) + else: + return tree.map_structure_with_path( + lambda path, component: get_placeholder( + space=component, + name=name + "." + ".".join([str(p) for p in path]), + ), + get_base_struct_from_space(space), + ) + return tf1.placeholder( + shape=(None, ) + ((None, ) if time_axis else ()) + space.shape, + dtype=tf.float32 if space.dtype == np.float64 else space.dtype, + name=name, + ) + else: + assert value is not None + shape = value.shape[1:] + return tf1.placeholder( + shape=(None, ) + ((None, ) + if time_axis else ()) + (shape if isinstance( + shape, tuple) else tuple(shape.as_list())), + dtype=tf.float32 if value.dtype == np.float64 else value.dtype, + name=name, + ) + + +def get_tf_eager_cls_if_necessary( + orig_cls: Type["TFPolicy"], + config: PartialTrainerConfigDict) -> Type["TFPolicy"]: + """Returns the corresponding tf-eager class for a given TFPolicy class. + + Args: + orig_cls: The original TFPolicy class to get the corresponding tf-eager + class for. + config: The Trainer config dict. + + Returns: + The tf eager policy class corresponding to the given TFPolicy class. + """ + cls = orig_cls + framework = config.get("framework", "tf") + if framework in ["tf2", "tf", "tfe"]: + if not tf1: + raise ImportError("Could not import tensorflow!") + if framework in ["tf2", "tfe"]: + assert tf1.executing_eagerly() + + from ray.rllib.policy.tf_policy import TFPolicy + + # Create eager-class. + if hasattr(orig_cls, "as_eager"): + cls = orig_cls.as_eager() + if config.get("eager_tracing"): + cls = cls.with_tracing() + # Could be some other type of policy. + elif not issubclass(orig_cls, TFPolicy): + pass + else: + raise ValueError("This policy does not support eager " + "execution: {}".format(orig_cls)) + return cls + + +def huber_loss(x: TensorType, delta: float = 1.0) -> TensorType: + """Computes the huber loss for a given term and delta parameter. + + Reference: https://en.wikipedia.org/wiki/Huber_loss + Note that the factor of 0.5 is implicitly included in the calculation. + + Formula: + L = 0.5 * x^2 for small abs x (delta threshold) + L = delta * (abs(x) - 0.5*delta) for larger abs x (delta threshold) + + Args: + x: The input term, e.g. a TD error. + delta: The delta parmameter in the above formula. + + Returns: + The Huber loss resulting from `x` and `delta`. + """ + return tf.where( + tf.abs(x) < delta, # for small x -> apply the Huber correction + tf.math.square(x) * 0.5, + delta * (tf.abs(x) - 0.5 * delta), + ) + + +def make_tf_callable(session_or_none: Optional["tf1.Session"], + dynamic_shape: bool = False) -> Callable: + """Returns a function that can be executed in either graph or eager mode. + + The function must take only positional args. + + If eager is enabled, this will act as just a function. Otherwise, it + will build a function that executes a session run with placeholders + internally. + + Args: + session_or_none: tf.Session if in graph mode, else None. + dynamic_shape: True if the placeholders should have a dynamic + batch dimension. Otherwise they will be fixed shape. + + Returns: + A function that can be called in either eager or static-graph mode. + """ + + if tf.executing_eagerly(): + assert session_or_none is None + else: + assert session_or_none is not None + + def make_wrapper(fn): + # Static-graph mode: Create placeholders and make a session call each + # time the wrapped function is called. Returns the output of this + # session call. + if session_or_none is not None: + args_placeholders = [] + kwargs_placeholders = {} + + symbolic_out = [None] + + def call(*args, **kwargs): + args_flat = [] + for a in args: + if type(a) is list: + args_flat.extend(a) + else: + args_flat.append(a) + args = args_flat + + # We have not built any placeholders yet: Do this once here, + # then reuse the same placeholders each time we call this + # function again. + if symbolic_out[0] is None: + with session_or_none.graph.as_default(): + + def _create_placeholders(path, value): + if dynamic_shape: + if len(value.shape) > 0: + shape = (None, ) + value.shape[1:] + else: + shape = () + else: + shape = value.shape + return tf1.placeholder( + dtype=value.dtype, + shape=shape, + name=".".join([str(p) for p in path]), + ) + + placeholders = tree.map_structure_with_path( + _create_placeholders, args) + for ph in tree.flatten(placeholders): + args_placeholders.append(ph) + + placeholders = tree.map_structure_with_path( + _create_placeholders, kwargs) + for k, ph in placeholders.items(): + kwargs_placeholders[k] = ph + + symbolic_out[0] = fn(*args_placeholders, + **kwargs_placeholders) + feed_dict = dict(zip(args_placeholders, tree.flatten(args))) + tree.map_structure(lambda ph, v: feed_dict.__setitem__(ph, v), + kwargs_placeholders, kwargs) + ret = session_or_none.run(symbolic_out[0], feed_dict) + return ret + + return call + # Eager mode (call function as is). + else: + return fn + + return make_wrapper + + +def minimize_and_clip( + optimizer: LocalOptimizer, + objective: TensorType, + var_list: List["tf.Variable"], + clip_val: float = 10.0, +) -> ModelGradients: + """Computes, then clips gradients using objective, optimizer and var list. + + Ensures the norm of the gradients for each variable is clipped to + `clip_val`. + + Args: + optimizer: Either a shim optimizer (tf eager) containing a + tf.GradientTape under `self.tape` or a tf1 local optimizer + object. + objective: The loss tensor to calculate gradients on. + var_list: The list of tf.Variables to compute gradients over. + clip_val: The global norm clip value. Will clip around -clip_val and + +clip_val. + + Returns: + The resulting model gradients (list or tuples of grads + vars) + corresponding to the input `var_list`. + """ + # Accidentally passing values < 0.0 will break all gradients. + assert clip_val is None or clip_val > 0.0, clip_val + + if tf.executing_eagerly(): + tape = optimizer.tape + grads_and_vars = list( + zip(list(tape.gradient(objective, var_list)), var_list)) + else: + grads_and_vars = optimizer.compute_gradients( + objective, var_list=var_list) + + return [(tf.clip_by_norm(g, clip_val) if clip_val is not None else g, v) + for (g, v) in grads_and_vars if g is not None] + + +def one_hot(x: TensorType, space: gym.Space) -> TensorType: + """Returns a one-hot tensor, given and int tensor and a space. + + Handles the MultiDiscrete case as well. + + Args: + x: The input tensor. + space: The space to use for generating the one-hot tensor. + + Returns: + The resulting one-hot tensor. + + Raises: + ValueError: If the given space is not a discrete one. + + Examples: + >>> x = tf.Variable([0, 3], dtype=tf.int32) # batch-dim=2 + >>> # Discrete space with 4 (one-hot) slots per batch item. + >>> s = gym.spaces.Discrete(4) + >>> one_hot(x, s) + + + >>> x = tf.Variable([[0, 1, 2, 3]], dtype=tf.int32) # batch-dim=1 + >>> # MultiDiscrete space with 5 + 4 + 4 + 7 = 20 (one-hot) slots + >>> # per batch item. + >>> s = gym.spaces.MultiDiscrete([5, 4, 4, 7]) + >>> one_hot(x, s) + + """ + if isinstance(space, Discrete): + return tf.one_hot(x, space.n, dtype=tf.float32) + elif isinstance(space, MultiDiscrete): + return tf.concat( + [ + tf.one_hot(x[:, i], n, dtype=tf.float32) + for i, n in enumerate(space.nvec) + ], + axis=-1) + else: + raise ValueError("Unsupported space for `one_hot`: {}".format(space)) + + +def reduce_mean_ignore_inf(x: TensorType, + axis: Optional[int] = None) -> TensorType: + """Same as tf.reduce_mean() but ignores -inf values. + + Args: + x: The input tensor to reduce mean over. + axis: The axis over which to reduce. None for all axes. + + Returns: + The mean reduced inputs, ignoring inf values. + """ + mask = tf.not_equal(x, tf.float32.min) + x_zeroed = tf.where(mask, x, tf.zeros_like(x)) + return (tf.math.reduce_sum(x_zeroed, axis) / tf.math.reduce_sum( + tf.cast(mask, tf.float32), axis)) + + +def scope_vars(scope: Union[str, "tf1.VariableScope"], + trainable_only: bool = False) -> List["tf.Variable"]: + """Get variables inside a given scope. + + Args: + scope: Scope in which the variables reside. + trainable_only: Whether or not to return only the variables that were + marked as trainable. + + Returns: + The list of variables in the given `scope`. + """ + return tf1.get_collection( + tf1.GraphKeys.TRAINABLE_VARIABLES + if trainable_only else tf1.GraphKeys.VARIABLES, + scope=scope if isinstance(scope, str) else scope.name) + + +def zero_logps_from_actions(actions: TensorStructType) -> TensorType: + """Helper function useful for returning dummy logp's (0) for some actions. + + Args: + actions: The input actions. This can be any struct + of complex action components or a simple tensor of different + dimensions, e.g. [B], [B, 2], or {"a": [B, 4, 5], "b": [B]}. + + Returns: + A 1D tensor of 0.0 (dummy logp's) matching the batch + dim of `actions` (shape=[B]). + """ + # Need to flatten `actions` in case we have a complex action space. + # Take the 0th component to extract the batch dim. + action_component = tree.flatten(actions)[0] + logp_ = tf.zeros_like(action_component, dtype=tf.float32) + # Logp's should be single values (but with the same batch dim as + # `deterministic_actions` or `stochastic_actions`). In case + # actions are just [B], zeros_like works just fine here, but if + # actions are [B, ...], we have to reduce logp back to just [B]. + while len(logp_.shape) > 1: + logp_ = logp_[:, 0] + return logp_ diff --git a/rllib/utils/threading.py b/rllib/utils/threading.py index a75f1d65c306..f6a3f7b4fa67 100644 --- a/rllib/utils/threading.py +++ b/rllib/utils/threading.py @@ -1,7 +1,7 @@ from typing import Callable -def with_lock(func: Callable): +def with_lock(func: Callable) -> Callable: """Use as decorator (@withlock) around object methods that need locking. Note: The object must have a self._lock = threading.Lock() property. @@ -9,10 +9,10 @@ def with_lock(func: Callable): object can be called asynchronously). Args: - func (Callable): The function to decorate/wrap. + func: The function to decorate/wrap. Returns: - Callable: The wrapped (object-level locked) function. + The wrapped (object-level locked) function. """ def wrapper(self, *a, **k): diff --git a/rllib/utils/torch_ops.py b/rllib/utils/torch_ops.py index 90ccc64aad12..8bc9eebbd116 100644 --- a/rllib/utils/torch_ops.py +++ b/rllib/utils/torch_ops.py @@ -1,255 +1,8 @@ -from gym.spaces import Discrete, MultiDiscrete -import numpy as np -import os -import tree # pip install dm_tree -import warnings - -from ray.rllib.models.repeated_values import RepeatedValues -from ray.rllib.utils.framework import try_import_torch -from ray.rllib.utils.numpy import SMALL_NUMBER - -torch, nn = try_import_torch() - -# Limit values suitable for use as close to a -inf logit. These are useful -# since -inf / inf cause NaNs during backprop. -FLOAT_MIN = -3.4e38 -FLOAT_MAX = 3.4e38 - - -def apply_grad_clipping(policy, optimizer, loss): - """Applies gradient clipping to already computed grads inside `optimizer`. - - Args: - policy (TorchPolicy): The TorchPolicy, which calculated `loss`. - optimizer (torch.optim.Optimizer): A local torch optimizer object. - loss (torch.Tensor): The torch loss tensor. - """ - info = {} - if policy.config["grad_clip"]: - for param_group in optimizer.param_groups: - # Make sure we only pass params with grad != None into torch - # clip_grad_norm_. Would fail otherwise. - params = list( - filter(lambda p: p.grad is not None, param_group["params"])) - if params: - grad_gnorm = nn.utils.clip_grad_norm_( - params, policy.config["grad_clip"]) - if isinstance(grad_gnorm, torch.Tensor): - grad_gnorm = grad_gnorm.cpu().numpy() - info["grad_gnorm"] = grad_gnorm - return info - - -def atanh(x): - return 0.5 * torch.log( - (1 + x).clamp(min=SMALL_NUMBER) / (1 - x).clamp(min=SMALL_NUMBER)) - - -def concat_multi_gpu_td_errors(policy): - td_error = torch.cat( - [ - t.tower_stats.get("td_error", torch.tensor([0.0])).to( - policy.device) for t in policy.model_gpu_towers - ], - dim=0) - policy.td_error = td_error - return { - "td_error": td_error, - "mean_td_error": torch.mean(td_error), - } - - -def convert_to_non_torch_type(stats): - """Converts values in `stats` to non-Tensor numpy or python types. - - Args: - stats (any): Any (possibly nested) struct, the values in which will be - converted and returned as a new struct with all torch tensors - being converted to numpy types. - - Returns: - Any: A new struct with the same structure as `stats`, but with all - values converted to non-torch Tensor types. - """ - - # The mapping function used to numpyize torch Tensors. - def mapping(item): - if isinstance(item, torch.Tensor): - return item.cpu().item() if len(item.size()) == 0 else \ - item.detach().cpu().numpy() - else: - return item - - return tree.map_structure(mapping, stats) - - -def convert_to_torch_tensor(x, device=None): - """Converts any struct to torch.Tensors. - - x (any): Any (possibly nested) struct, the values in which will be - converted and returned as a new struct with all leaves converted - to torch tensors. - - Returns: - Any: A new struct with the same structure as `stats`, but with all - values converted to torch Tensor types. - """ - - def mapping(item): - # Already torch tensor -> make sure it's on right device. - if torch.is_tensor(item): - return item if device is None else item.to(device) - # Special handling of "Repeated" values. - elif isinstance(item, RepeatedValues): - return RepeatedValues( - tree.map_structure(mapping, item.values), item.lengths, - item.max_len) - # Numpy arrays. - if isinstance(item, np.ndarray): - # np.object_ type (e.g. info dicts in train batch): leave as-is. - if item.dtype == np.object_: - return item - # Non-writable numpy-arrays will cause PyTorch warning. - elif item.flags.writeable is False: - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - tensor = torch.from_numpy(item) - # Already numpy: Wrap as torch tensor. - else: - tensor = torch.from_numpy(item) - # Everything else: Convert to numpy, then wrap as torch tensor. - else: - tensor = torch.from_numpy(np.asarray(item)) - # Floatify all float64 tensors. - if tensor.dtype == torch.double: - tensor = tensor.float() - return tensor if device is None else tensor.to(device) - - return tree.map_structure(mapping, x) - - -def explained_variance(y, pred): - y_var = torch.var(y, dim=[0]) - diff_var = torch.var(y - pred, dim=[0]) - min_ = torch.tensor([-1.0]).to(pred.device) - return torch.max(min_, 1 - (diff_var / y_var))[0] - - -def global_norm(tensors): - """Returns the global L2 norm over a list of tensors. - - output = sqrt(SUM(t ** 2 for t in tensors)), - where SUM reduces over all tensors and over all elements in tensors. - - Args: - tensors (List[torch.Tensor]): The list of tensors to calculate the - global norm over. - """ - # List of single tensors' L2 norms: SQRT(SUM(xi^2)) over all xi in tensor. - single_l2s = [ - torch.pow(torch.sum(torch.pow(t, 2.0)), 0.5) for t in tensors - ] - # Compute global norm from all single tensors' L2 norms. - return torch.pow(sum(torch.pow(l2, 2.0) for l2 in single_l2s), 0.5) - - -def huber_loss(x, delta=1.0): - """Reference: https://en.wikipedia.org/wiki/Huber_loss""" - return torch.where( - torch.abs(x) < delta, - torch.pow(x, 2.0) * 0.5, delta * (torch.abs(x) - 0.5 * delta)) - - -def l2_loss(x): - """Computes half the L2 norm of a tensor without the sqrt. - - output = sum(x ** 2) / 2 - """ - return torch.sum(torch.pow(x, 2.0)) / 2.0 - - -def minimize_and_clip(optimizer, clip_val=10): - """Clips gradients found in `optimizer.param_groups` to given value. - - Ensures the norm of the gradients for each variable is clipped to - `clip_val` - """ - for param_group in optimizer.param_groups: - for p in param_group["params"]: - if p.grad is not None: - torch.nn.utils.clip_grad_norm_(p.grad, clip_val) - - -def one_hot(x, space): - if isinstance(space, Discrete): - return nn.functional.one_hot(x.long(), space.n) - elif isinstance(space, MultiDiscrete): - return torch.cat( - [ - nn.functional.one_hot(x[:, i].long(), n) - for i, n in enumerate(space.nvec) - ], - dim=-1) - else: - raise ValueError("Unsupported space for `one_hot`: {}".format(space)) - - -def reduce_mean_ignore_inf(x, axis): - """Same as torch.mean() but ignores -inf values.""" - mask = torch.ne(x, float("-inf")) - x_zeroed = torch.where(mask, x, torch.zeros_like(x)) - return torch.sum(x_zeroed, axis) / torch.sum(mask.float(), axis) - - -def sequence_mask(lengths, maxlen=None, dtype=None, time_major=False): - """Offers same behavior as tf.sequence_mask for torch. - - Thanks to Dimitris Papatheodorou - (https://discuss.pytorch.org/t/pytorch-equivalent-for-tf-sequence-mask/ - 39036). - """ - if maxlen is None: - maxlen = int(lengths.max()) - - mask = ~(torch.ones( - (len(lengths), maxlen)).to(lengths.device).cumsum(dim=1).t() > lengths) - if not time_major: - mask = mask.t() - mask.type(dtype or torch.bool) - - return mask - - -def set_torch_seed(seed): - if seed is not None and torch: - torch.manual_seed(seed) - # See https://github.com/pytorch/pytorch/issues/47672. - cuda_version = torch.version.cuda - if cuda_version is not None and float(torch.version.cuda) >= 10.2: - os.environ["CUBLAS_WORKSPACE_CONFIG"] = "4096:8" - else: - # Not all Operations support this. - torch.use_deterministic_algorithms(True) - # This is only for Convolution no problem. - torch.backends.cudnn.deterministic = True - - -def softmax_cross_entropy_with_logits(logits, labels): - """Same behavior as tf.nn.softmax_cross_entropy_with_logits. - - Args: - x (TensorType): - - Returns: - - """ - return torch.sum(-labels * nn.functional.log_softmax(logits, -1), -1) - - -class Swish(nn.Module): - def __init__(self): - super().__init__() - self._beta = nn.Parameter(torch.tensor(1.0)) - - def forward(self, input_tensor): - return input_tensor * torch.sigmoid(self._beta * input_tensor) +from ray.rllib.utils.deprecation import deprecation_warning +from ray.rllib.utils.torch_utils import * # noqa + +deprecation_warning( + old="ray.rllib.utils.torch_ops.[...]", + new="ray.rllib.utils.torch_utils.[...]", + error=False, +) diff --git a/rllib/utils/torch_utils.py b/rllib/utils/torch_utils.py new file mode 100644 index 000000000000..19ea8ded7cfc --- /dev/null +++ b/rllib/utils/torch_utils.py @@ -0,0 +1,395 @@ +import gym +from gym.spaces import Discrete, MultiDiscrete +import numpy as np +import os +import tree # pip install dm_tree +from typing import Dict, List, Optional, TYPE_CHECKING +import warnings + +from ray.rllib.models.repeated_values import RepeatedValues +from ray.rllib.utils.deprecation import Deprecated +from ray.rllib.utils.framework import try_import_torch +from ray.rllib.utils.numpy import SMALL_NUMBER +from ray.rllib.utils.typing import LocalOptimizer, TensorType, TensorStructType + +if TYPE_CHECKING: + from ray.rllib.policy.torch_policy import TorchPolicy + +torch, nn = try_import_torch() + +# Limit values suitable for use as close to a -inf logit. These are useful +# since -inf / inf cause NaNs during backprop. +FLOAT_MIN = -3.4e38 +FLOAT_MAX = 3.4e38 + + +def apply_grad_clipping(policy: "TorchPolicy", optimizer: LocalOptimizer, + loss: TensorType) -> Dict[str, TensorType]: + """Applies gradient clipping to already computed grads inside `optimizer`. + + Args: + policy: The TorchPolicy, which calculated `loss`. + optimizer: A local torch optimizer object. + loss: The torch loss tensor. + + Returns: + An info dict containing the "grad_norm" key and the resulting clipped + gradients. + """ + info = {} + if policy.config["grad_clip"]: + for param_group in optimizer.param_groups: + # Make sure we only pass params with grad != None into torch + # clip_grad_norm_. Would fail otherwise. + params = list( + filter(lambda p: p.grad is not None, param_group["params"])) + if params: + grad_gnorm = nn.utils.clip_grad_norm_( + params, policy.config["grad_clip"]) + if isinstance(grad_gnorm, torch.Tensor): + grad_gnorm = grad_gnorm.cpu().numpy() + info["grad_gnorm"] = grad_gnorm + return info + + +@Deprecated( + old="ray.rllib.utils.torch_utils.atanh", + new="torch.math.atanh", + error=False) +def atanh(x: TensorType) -> TensorType: + """Atanh function for PyTorch.""" + return 0.5 * torch.log( + (1 + x).clamp(min=SMALL_NUMBER) / (1 - x).clamp(min=SMALL_NUMBER)) + + +def concat_multi_gpu_td_errors(policy: "TorchPolicy") -> Dict[str, TensorType]: + """Concatenates multi-GPU (per-tower) TD error tensors given TorchPolicy. + + TD-errors are extracted from the TorchPolicy via its tower_stats property. + + Args: + policy: The TorchPolicy to extract the TD-error values from. + + Returns: + A dict mapping strings "td_error" and "mean_td_error" to the + corresponding concatenated and mean-reduced values. + """ + td_error = torch.cat( + [ + t.tower_stats.get("td_error", torch.tensor([0.0])).to( + policy.device) for t in policy.model_gpu_towers + ], + dim=0) + policy.td_error = td_error + return { + "td_error": td_error, + "mean_td_error": torch.mean(td_error), + } + + +@Deprecated(new="ray/rllib/utils/numpy.py::convert_to_numpy", error=False) +def convert_to_non_torch_type(stats: TensorStructType) -> TensorStructType: + """Converts values in `stats` to non-Tensor numpy or python types. + + Args: + stats (any): Any (possibly nested) struct, the values in which will be + converted and returned as a new struct with all torch tensors + being converted to numpy types. + + Returns: + Any: A new struct with the same structure as `stats`, but with all + values converted to non-torch Tensor types. + """ + + # The mapping function used to numpyize torch Tensors. + def mapping(item): + if isinstance(item, torch.Tensor): + return item.cpu().item() if len(item.size()) == 0 else \ + item.detach().cpu().numpy() + else: + return item + + return tree.map_structure(mapping, stats) + + +def convert_to_torch_tensor(x: TensorStructType, device: Optional[str] = None): + """Converts any struct to torch.Tensors. + + x (any): Any (possibly nested) struct, the values in which will be + converted and returned as a new struct with all leaves converted + to torch tensors. + + Returns: + Any: A new struct with the same structure as `stats`, but with all + values converted to torch Tensor types. + """ + + def mapping(item): + # Already torch tensor -> make sure it's on right device. + if torch.is_tensor(item): + return item if device is None else item.to(device) + # Special handling of "Repeated" values. + elif isinstance(item, RepeatedValues): + return RepeatedValues( + tree.map_structure(mapping, item.values), item.lengths, + item.max_len) + # Numpy arrays. + if isinstance(item, np.ndarray): + # np.object_ type (e.g. info dicts in train batch): leave as-is. + if item.dtype == np.object_: + return item + # Non-writable numpy-arrays will cause PyTorch warning. + elif item.flags.writeable is False: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + tensor = torch.from_numpy(item) + # Already numpy: Wrap as torch tensor. + else: + tensor = torch.from_numpy(item) + # Everything else: Convert to numpy, then wrap as torch tensor. + else: + tensor = torch.from_numpy(np.asarray(item)) + # Floatify all float64 tensors. + if tensor.dtype == torch.double: + tensor = tensor.float() + return tensor if device is None else tensor.to(device) + + return tree.map_structure(mapping, x) + + +def explained_variance(y: TensorType, pred: TensorType) -> TensorType: + """Computes the explained variance for a pair of labels and predictions. + + The formula used is: + max(-1.0, 1.0 - (std(y - pred)^2 / std(y)^2)) + + Args: + y: The labels. + pred: The predictions. + + Returns: + The explained variance given a pair of labels and predictions. + """ + y_var = torch.var(y, dim=[0]) + diff_var = torch.var(y - pred, dim=[0]) + min_ = torch.tensor([-1.0]).to(pred.device) + return torch.max(min_, 1 - (diff_var / y_var))[0] + + +def global_norm(tensors: List[TensorType]) -> TensorType: + """Returns the global L2 norm over a list of tensors. + + output = sqrt(SUM(t ** 2 for t in tensors)), + where SUM reduces over all tensors and over all elements in tensors. + + Args: + tensors: The list of tensors to calculate the global norm over. + + Returns: + The global L2 norm over the given tensor list. + """ + # List of single tensors' L2 norms: SQRT(SUM(xi^2)) over all xi in tensor. + single_l2s = [ + torch.pow(torch.sum(torch.pow(t, 2.0)), 0.5) for t in tensors + ] + # Compute global norm from all single tensors' L2 norms. + return torch.pow(sum(torch.pow(l2, 2.0) for l2 in single_l2s), 0.5) + + +def huber_loss(x: TensorType, delta: float = 1.0) -> TensorType: + """Computes the huber loss for a given term and delta parameter. + + Reference: https://en.wikipedia.org/wiki/Huber_loss + Note that the factor of 0.5 is implicitly included in the calculation. + + Formula: + L = 0.5 * x^2 for small abs x (delta threshold) + L = delta * (abs(x) - 0.5*delta) for larger abs x (delta threshold) + + Args: + x: The input term, e.g. a TD error. + delta: The delta parmameter in the above formula. + + Returns: + The Huber loss resulting from `x` and `delta`. + """ + return torch.where( + torch.abs(x) < delta, + torch.pow(x, 2.0) * 0.5, delta * (torch.abs(x) - 0.5 * delta)) + + +def l2_loss(x: TensorType) -> TensorType: + """Computes half the L2 norm over a tensor's values without the sqrt. + + output = 0.5 * sum(x ** 2) + + Args: + x: The input tensor. + + Returns: + 0.5 times the L2 norm over the given tensor's values (w/o sqrt). + """ + return 0.5 * torch.sum(torch.pow(x, 2.0)) + + +def minimize_and_clip(optimizer: "torch.optim.Optimizer", + clip_val: float = 10.0) -> None: + """Clips grads found in `optimizer.param_groups` to given value in place. + + Ensures the norm of the gradients for each variable is clipped to + `clip_val`. + + Args: + optimizer: The torch.optim.Optimizer to get the variables from. + clip_val: The global norm clip value. Will clip around -clip_val and + +clip_val. + """ + # Loop through optimizer's variables and norm per variable. + for param_group in optimizer.param_groups: + for p in param_group["params"]: + if p.grad is not None: + torch.nn.utils.clip_grad_norm_(p.grad, clip_val) + + +def one_hot(x: TensorType, space: gym.Space) -> TensorType: + """Returns a one-hot tensor, given and int tensor and a space. + + Handles the MultiDiscrete case as well. + + Args: + x: The input tensor. + space: The space to use for generating the one-hot tensor. + + Returns: + The resulting one-hot tensor. + + Raises: + ValueError: If the given space is not a discrete one. + + Examples: + >>> x = torch.IntTensor([0, 3]) # batch-dim=2 + >>> # Discrete space with 4 (one-hot) slots per batch item. + >>> s = gym.spaces.Discrete(4) + >>> one_hot(x, s) + tensor([[1, 0, 0, 0], [0, 0, 0, 1]]) + + >>> x = torch.IntTensor([[0, 1, 2, 3]]) # batch-dim=1 + >>> # MultiDiscrete space with 5 + 4 + 4 + 7 = 20 (one-hot) slots + >>> # per batch item. + >>> s = gym.spaces.MultiDiscrete([5, 4, 4, 7]) + >>> one_hot(x, s) + tensor([[1, 0, 0, 0, 0, + 0, 1, 0, 0, + 0, 0, 1, 0, + 0, 0, 0, 1, 0, 0, 0]]) + """ + if isinstance(space, Discrete): + return nn.functional.one_hot(x.long(), space.n) + elif isinstance(space, MultiDiscrete): + return torch.cat( + [ + nn.functional.one_hot(x[:, i].long(), n) + for i, n in enumerate(space.nvec) + ], + dim=-1) + else: + raise ValueError("Unsupported space for `one_hot`: {}".format(space)) + + +def reduce_mean_ignore_inf(x: TensorType, + axis: Optional[int] = None) -> TensorType: + """Same as torch.mean() but ignores -inf values. + + Args: + x: The input tensor to reduce mean over. + axis: The axis over which to reduce. None for all axes. + + Returns: + The mean reduced inputs, ignoring inf values. + """ + mask = torch.ne(x, float("-inf")) + x_zeroed = torch.where(mask, x, torch.zeros_like(x)) + return torch.sum(x_zeroed, axis) / torch.sum(mask.float(), axis) + + +def sequence_mask( + lengths: TensorType, + maxlen: Optional[int] = None, + dtype=None, + time_major: bool = False, +) -> TensorType: + """Offers same behavior as tf.sequence_mask for torch. + + Thanks to Dimitris Papatheodorou + (https://discuss.pytorch.org/t/pytorch-equivalent-for-tf-sequence-mask/ + 39036). + + Args: + lengths: The tensor of individual lengths to mask by. + maxlen: The maximum length to use for the time axis. If None, use + the max of `lengths`. + dtype: The torch dtype to use for the resulting mask. + time_major: Whether to return the mask as [B, T] (False; default) or + as [T, B] (True). + + Returns: + The sequence mask resulting from the given input and parameters. + """ + # If maxlen not given, use the longest lengths in the `lengths` tensor. + if maxlen is None: + maxlen = int(lengths.max()) + + mask = ~(torch.ones( + (len(lengths), maxlen)).to(lengths.device).cumsum(dim=1).t() > lengths) + # Time major transformation. + if not time_major: + mask = mask.t() + + # By default, set the mask to be boolean. + mask.type(dtype or torch.bool) + + return mask + + +def set_torch_seed(seed: Optional[int] = None) -> None: + """Sets the torch random seed to the given value. + + Args: + seed: The seed to use or None for no seeding. + """ + if seed is not None and torch: + torch.manual_seed(seed) + # See https://github.com/pytorch/pytorch/issues/47672. + cuda_version = torch.version.cuda + if cuda_version is not None and float(torch.version.cuda) >= 10.2: + os.environ["CUBLAS_WORKSPACE_CONFIG"] = "4096:8" + else: + # Not all Operations support this. + torch.use_deterministic_algorithms(True) + # This is only for Convolution no problem. + torch.backends.cudnn.deterministic = True + + +def softmax_cross_entropy_with_logits( + logits: TensorType, + labels: TensorType, +) -> TensorType: + """Same behavior as tf.nn.softmax_cross_entropy_with_logits. + + Args: + x: The input predictions. + labels: The labels corresponding to `x`. + + Returns: + The resulting softmax cross-entropy given predictions and labels. + """ + return torch.sum(-labels * nn.functional.log_softmax(logits, -1), -1) + + +class Swish(nn.Module): + def __init__(self): + super().__init__() + self._beta = nn.Parameter(torch.tensor(1.0)) + + def forward(self, input_tensor): + return input_tensor * torch.sigmoid(self._beta * input_tensor) diff --git a/rllib/utils/window_stat.py b/rllib/utils/window_stat.py index 9aa0d9f301df..873d803914af 100644 --- a/rllib/utils/window_stat.py +++ b/rllib/utils/window_stat.py @@ -1,28 +1,9 @@ -import numpy as np - - -class WindowStat: - def __init__(self, name, n): - self.name = name - self.items = [None] * n - self.idx = 0 - self.count = 0 - - def push(self, obj): - self.items[self.idx] = obj - self.idx += 1 - self.count += 1 - self.idx %= len(self.items) - - def stats(self): - if not self.count: - _quantiles = [] - else: - _quantiles = np.nanpercentile(self.items[:self.count], - [0, 10, 50, 90, 100]).tolist() - return { - self.name + "_count": int(self.count), - self.name + "_mean": float(np.nanmean(self.items[:self.count])), - self.name + "_std": float(np.nanstd(self.items[:self.count])), - self.name + "_quantiles": _quantiles, - } +from ray.rllib.utils.deprecation import deprecation_warning +from ray.rllib.utils.metrics.window_stat import WindowStat + +deprecation_warning( + old="ray.rllib.utils.window_stat.WindowStat", + new="ray.rllib.utils.metrics.window_stat.WindowStat", + error=False, +) +WindowStat = WindowStat From 05c63f02082b715b8ff1adf001d86faa243307b3 Mon Sep 17 00:00:00 2001 From: Jiajun Yao Date: Mon, 1 Nov 2021 15:50:38 -0700 Subject: [PATCH 08/15] [workflow] Mark workflow test_recovery as large test (#19950) ## Why are these changes needed? move test_recovery to large test ## Related issue number --- python/ray/workflow/BUILD | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/python/ray/workflow/BUILD b/python/ray/workflow/BUILD index 3b04110be7dc..b995b61d9ef9 100644 --- a/python/ray/workflow/BUILD +++ b/python/ray/workflow/BUILD @@ -13,10 +13,20 @@ SRCS = [] + select({ "//conditions:default": [], }) +LARGE_TESTS = ["tests/test_recovery.py"] + py_test_module_list( - files = glob(["tests/test_*.py", "examples/**/*.py"]), + files = glob(["tests/test_*.py", "examples/**/*.py"], exclude=LARGE_TESTS), size = "medium", extra_srcs = SRCS, tags = ["team:core", "exclusive"], deps = ["//:ray_lib"], ) + +py_test_module_list( + files = LARGE_TESTS, + size = "large", + extra_srcs = SRCS, + tags = ["team:core", "exclusive"], + deps = ["//:ray_lib"], +) From e1e4a45b8df8c898963faab3ce40f591d1ac1271 Mon Sep 17 00:00:00 2001 From: matthewdeng Date: Mon, 1 Nov 2021 18:25:19 -0700 Subject: [PATCH 09/15] [train] add simple Ray Train release tests (#19817) * [train] add simple Ray Train release tests * simplify tests * update * driver requirements * move to test * remove connect * fix * fix * fix torch * gpu * add assert * remove assert * use gloo backend * fix * finish Co-authored-by: Amog Kamsetty --- .../examples/tensorflow_mnist_example.py | 12 ++++-- .../train/examples/train_linear_example.py | 40 ++++++++++++++----- release/.buildkite/build_pipeline.py | 13 +++++- release/train_tests/app_config.yaml | 13 ++++++ release/train_tests/compute_tpl.yaml | 15 +++++++ release/train_tests/driver_requirements.txt | 8 ++++ release/train_tests/train_tests.yaml | 17 ++++++++ .../workloads/train_tensorflow_mnist_test.py | 31 ++++++++++++++ .../workloads/train_torch_linear_test.py | 30 ++++++++++++++ 9 files changed, 165 insertions(+), 14 deletions(-) create mode 100644 release/train_tests/app_config.yaml create mode 100644 release/train_tests/compute_tpl.yaml create mode 100644 release/train_tests/driver_requirements.txt create mode 100644 release/train_tests/train_tests.yaml create mode 100644 release/train_tests/workloads/train_tensorflow_mnist_test.py create mode 100644 release/train_tests/workloads/train_torch_linear_test.py diff --git a/python/ray/train/examples/tensorflow_mnist_example.py b/python/ray/train/examples/tensorflow_mnist_example.py index 5f71842bf78b..0880f3347cde 100644 --- a/python/ray/train/examples/tensorflow_mnist_example.py +++ b/python/ray/train/examples/tensorflow_mnist_example.py @@ -72,7 +72,7 @@ def train_func(config): return results -def train_tensorflow_mnist(num_workers=2, use_gpu=False): +def train_tensorflow_mnist(num_workers=2, use_gpu=False, epochs=4): trainer = Trainer( backend="tensorflow", num_workers=num_workers, use_gpu=use_gpu) trainer.start() @@ -81,7 +81,7 @@ def train_tensorflow_mnist(num_workers=2, use_gpu=False): config={ "lr": 1e-3, "batch_size": 64, - "epochs": 4 + "epochs": epochs }) trainer.shutdown() print(f"Results: {results[0]}") @@ -105,6 +105,8 @@ def train_tensorflow_mnist(num_workers=2, use_gpu=False): action="store_true", default=False, help="Enables GPU training") + parser.add_argument( + "--epochs", type=int, default=3, help="Number of epochs to train for.") parser.add_argument( "--smoke-test", action="store_true", @@ -117,6 +119,10 @@ def train_tensorflow_mnist(num_workers=2, use_gpu=False): if args.smoke_test: ray.init(num_cpus=2) + train_tensorflow_mnist() else: ray.init(address=args.address) - train_tensorflow_mnist(num_workers=args.num_workers, use_gpu=args.use_gpu) + train_tensorflow_mnist( + num_workers=args.num_workers, + use_gpu=args.use_gpu, + epochs=args.epochs) diff --git a/python/ray/train/examples/train_linear_example.py b/python/ray/train/examples/train_linear_example.py index 50bfbd0fe2aa..2512f022a596 100644 --- a/python/ray/train/examples/train_linear_example.py +++ b/python/ray/train/examples/train_linear_example.py @@ -25,8 +25,9 @@ def __len__(self): return len(self.x) -def train_epoch(dataloader, model, loss_fn, optimizer): +def train_epoch(dataloader, model, loss_fn, optimizer, device): for X, y in dataloader: + X, y = X.to(device), y.to(device) # Compute prediction error pred = model(X) loss = loss_fn(pred, y) @@ -37,12 +38,13 @@ def train_epoch(dataloader, model, loss_fn, optimizer): optimizer.step() -def validate_epoch(dataloader, model, loss_fn): +def validate_epoch(dataloader, model, loss_fn, device): num_batches = len(dataloader) model.eval() loss = 0 with torch.no_grad(): for X, y in dataloader: + X, y = X.to(device), y.to(device) pred = model(X) loss += loss_fn(pred, y).item() loss /= num_batches @@ -58,6 +60,9 @@ def train_func(config): lr = config.get("lr", 1e-2) epochs = config.get("epochs", 3) + device = torch.device(f"cuda:{train.local_rank()}" + if torch.cuda.is_available() else "cpu") + train_dataset = LinearDataset(2, 5, size=data_size) val_dataset = LinearDataset(2, 5, size=val_size) train_loader = torch.utils.data.DataLoader( @@ -70,7 +75,10 @@ def train_func(config): sampler=DistributedSampler(val_dataset)) model = nn.Linear(1, hidden_size) - model = DistributedDataParallel(model) + model.to(device) + model = DistributedDataParallel( + model, + device_ids=[device.index] if torch.cuda.is_available() else None) loss_fn = nn.MSELoss() @@ -79,17 +87,20 @@ def train_func(config): results = [] for _ in range(epochs): - train_epoch(train_loader, model, loss_fn, optimizer) - result = validate_epoch(validation_loader, model, loss_fn) + train_epoch(train_loader, model, loss_fn, optimizer, device) + result = validate_epoch(validation_loader, model, loss_fn, device) train.report(**result) results.append(result) return results -def train_linear(num_workers=2): - trainer = Trainer(TorchConfig(backend="gloo"), num_workers=num_workers) - config = {"lr": 1e-2, "hidden_size": 1, "batch_size": 4, "epochs": 3} +def train_linear(num_workers=2, use_gpu=False, epochs=3): + trainer = Trainer( + backend=TorchConfig(backend="gloo"), + num_workers=num_workers, + use_gpu=use_gpu) + config = {"lr": 1e-2, "hidden_size": 1, "batch_size": 4, "epochs": epochs} trainer.start() results = trainer.run( train_func, @@ -115,6 +126,12 @@ def train_linear(num_workers=2): type=int, default=2, help="Sets number of workers for training.") + parser.add_argument( + "--use-gpu", + action="store_true", + help="Whether to use GPU for training.") + parser.add_argument( + "--epochs", type=int, default=3, help="Number of epochs to train for.") parser.add_argument( "--smoke-test", action="store_true", @@ -127,7 +144,10 @@ def train_linear(num_workers=2): if args.smoke_test: ray.init(num_cpus=2) + train_linear() else: ray.init(address=args.address) - - train_linear(num_workers=args.num_workers) + train_linear( + num_workers=args.num_workers, + use_gpu=args.use_gpu, + epochs=args.epochs) diff --git a/release/.buildkite/build_pipeline.py b/release/.buildkite/build_pipeline.py index 5faa766d8190..c2de1fe25302 100644 --- a/release/.buildkite/build_pipeline.py +++ b/release/.buildkite/build_pipeline.py @@ -250,7 +250,18 @@ def __init__(self, # 2. Use autoscaling/scale up (no wait_cluster.py) # 3. Use GPUs if applicable # 4. Have the `use_connect` flag set. -USER_TESTS = {} +USER_TESTS = { + "~/ray/release/train_tests/train_tests.yaml": [ + ConnectTest( + "train_tensorflow_mnist_test", + requirements_file="release/train_tests" + "/driver_requirements.txt"), + ConnectTest( + "train_torch_linear_test", + requirements_file="release/train_tests" + "/driver_requirements.txt") + ] +} SUITES = { "core-nightly": CORE_NIGHTLY_TESTS, diff --git a/release/train_tests/app_config.yaml b/release/train_tests/app_config.yaml new file mode 100644 index 000000000000..446b53847dfa --- /dev/null +++ b/release/train_tests/app_config.yaml @@ -0,0 +1,13 @@ +base_image: "anyscale/ray-ml:pinned-nightly-py37-gpu" +env_vars: { } +debian_packages: + - curl + +python: + pip_packages: [ ] + conda_packages: [ ] + +post_build_cmds: + - pip3 uninstall -y ray || true + - pip3 install -U {{ env["RAY_WHEELS"] | default("ray") }} + - {{ env["RAY_WHEELS_SANITY_CHECK"] | default("echo No Ray wheels sanity check") }} diff --git a/release/train_tests/compute_tpl.yaml b/release/train_tests/compute_tpl.yaml new file mode 100644 index 000000000000..221bb8f66548 --- /dev/null +++ b/release/train_tests/compute_tpl.yaml @@ -0,0 +1,15 @@ +cloud_id: {{env["ANYSCALE_CLOUD_ID"]}} +region: us-west-2 + +max_workers: 2 + +head_node_type: + name: head_node + instance_type: g3.8xlarge + +worker_node_types: + - name: worker_node + instance_type: g3.8xlarge + min_workers: 0 + max_workers: 2 + use_spot: false diff --git a/release/train_tests/driver_requirements.txt b/release/train_tests/driver_requirements.txt new file mode 100644 index 000000000000..b779be3a31dd --- /dev/null +++ b/release/train_tests/driver_requirements.txt @@ -0,0 +1,8 @@ +# Make sure the driver versions are the same as cluster versions. +# The cluster uses ray-ml Docker image. +# ray-ml Docker image installs dependencies from ray/python/requirements/ml/ directory. +# We constrain on these requirements file so that the same versions are installed. +-c ../../python/requirements/ml/requirements_dl.txt + +torch +tensorflow \ No newline at end of file diff --git a/release/train_tests/train_tests.yaml b/release/train_tests/train_tests.yaml new file mode 100644 index 000000000000..c19493f85d56 --- /dev/null +++ b/release/train_tests/train_tests.yaml @@ -0,0 +1,17 @@ +- name: train_tensorflow_mnist_test + cluster: + app_config: app_config.yaml + compute_template: compute_tpl.yaml + + run: + timeout: 36000 + script: python workloads/train_tensorflow_mnist_test.py + +- name: train_torch_linear_test + cluster: + app_config: app_config.yaml + compute_template: compute_tpl.yaml + + run: + timeout: 36000 + script: python workloads/train_torch_linear_test.py diff --git a/release/train_tests/workloads/train_tensorflow_mnist_test.py b/release/train_tests/workloads/train_tensorflow_mnist_test.py new file mode 100644 index 000000000000..376979d93c3a --- /dev/null +++ b/release/train_tests/workloads/train_tensorflow_mnist_test.py @@ -0,0 +1,31 @@ +import json +import os +import time + +import ray +from ray.train.examples.tensorflow_mnist_example import train_tensorflow_mnist + +if __name__ == "__main__": + start = time.time() + + addr = os.environ.get("RAY_ADDRESS") + job_name = os.environ.get("RAY_JOB_NAME", "horovod_user_test") + + if addr is not None and addr.startswith("anyscale://"): + ray.init(address=addr, job_name=job_name) + else: + ray.init(address="auto") + + train_tensorflow_mnist(num_workers=6, use_gpu=True, epochs=20) + + taken = time.time() - start + result = { + "time_taken": taken, + } + test_output_json = os.environ.get("TEST_OUTPUT_JSON", + "/tmp/train_torc_linear_test.json") + + with open(test_output_json, "wt") as f: + json.dump(result, f) + + print("Test Successful!") diff --git a/release/train_tests/workloads/train_torch_linear_test.py b/release/train_tests/workloads/train_torch_linear_test.py new file mode 100644 index 000000000000..fe013a8ef971 --- /dev/null +++ b/release/train_tests/workloads/train_torch_linear_test.py @@ -0,0 +1,30 @@ +import json +import os +import time + +import ray + +from ray.train.examples.train_linear_example import train_linear + +if __name__ == "__main__": + start = time.time() + + addr = os.environ.get("RAY_ADDRESS") + job_name = os.environ.get("RAY_JOB_NAME", "horovod_user_test") + + if addr is not None and addr.startswith("anyscale://"): + ray.init(address=addr, job_name=job_name) + else: + ray.init(address="auto") + + results = train_linear(num_workers=6, use_gpu=True, epochs=20) + + taken = time.time() - start + result = {"time_taken": taken} + test_output_json = os.environ.get("TEST_OUTPUT_JSON", + "/tmp/train_torc_linear_test.json") + + with open(test_output_json, "wt") as f: + json.dump(result, f) + + print("Test Successful!") From 474e44f7e01ed7661a244b62356be76cf9cf1131 Mon Sep 17 00:00:00 2001 From: Amog Kamsetty Date: Mon, 1 Nov 2021 18:28:07 -0700 Subject: [PATCH 10/15] [Release/Horovod] Add user test for Horovod (#19661) * infra * wip * add test * typo * typo * update * rename * fix * full path * formatting * reorder * update * update * Update release/horovod_tests/workloads/horovod_user_test.py Co-authored-by: matthewdeng * bump num_workers * update installs * try * add pip_packages * min_workers * fix * bump pg timeout * Fix symlink * fix * fix * cmake * fix * pin filelock * final * update * fix * Update release/horovod_tests/workloads/horovod_user_test.py * fix * fix * separate compute template * test latest and master Co-authored-by: matthewdeng --- .../examples/horovod_cifar_pbt_example.py | 2 +- python/ray/util/horovod/horovod_example.py | 15 ++++-- release/.buildkite/build_pipeline.py | 52 ++++++++++++++----- release/horovod_tests/app_config.yaml | 4 +- release/horovod_tests/app_config_master.yaml | 20 +++++++ release/horovod_tests/base_driver_reqs.txt | 8 +++ release/horovod_tests/compute_tpl.yaml | 2 +- .../compute_tpl_autoscaling.yaml | 24 +++++++++ release/horovod_tests/driver_requirements.txt | 3 ++ .../driver_requirements_master.txt | 4 ++ release/horovod_tests/horovod_tests.yaml | 24 ++++++++- .../{horovod_test.py => horovod_tune_test.py} | 0 .../workloads/horovod_user_test.py | 33 ++++++++++++ 13 files changed, 171 insertions(+), 20 deletions(-) create mode 100644 release/horovod_tests/app_config_master.yaml create mode 100644 release/horovod_tests/base_driver_reqs.txt create mode 100644 release/horovod_tests/compute_tpl_autoscaling.yaml create mode 100644 release/horovod_tests/driver_requirements.txt create mode 100644 release/horovod_tests/driver_requirements_master.txt rename release/horovod_tests/workloads/{horovod_test.py => horovod_tune_test.py} (100%) create mode 100644 release/horovod_tests/workloads/horovod_user_test.py diff --git a/python/ray/tune/examples/horovod_cifar_pbt_example.py b/python/ray/tune/examples/horovod_cifar_pbt_example.py index 4bd6bd44dd8d..d66c985f4f1e 120000 --- a/python/ray/tune/examples/horovod_cifar_pbt_example.py +++ b/python/ray/tune/examples/horovod_cifar_pbt_example.py @@ -1 +1 @@ -../../../../release/horovod_tests/workloads/horovod_test.py \ No newline at end of file +../../../../release/horovod_tests/workloads/horovod_tune_test.py \ No newline at end of file diff --git a/python/ray/util/horovod/horovod_example.py b/python/ray/util/horovod/horovod_example.py index 59aa0850245e..1e285d0128af 100644 --- a/python/ray/util/horovod/horovod_example.py +++ b/python/ray/util/horovod/horovod_example.py @@ -115,11 +115,20 @@ def train_fn(data_dir=None, 100. * batch_idx / len(train_loader), loss.item())) -def main(num_workers, use_gpu, **kwargs): - settings = RayExecutor.create_settings(timeout_s=30) +def main(num_workers, + use_gpu, + timeout_s=30, + placement_group_timeout_s=100, + kwargs=None): + kwargs = kwargs or {} + if use_gpu: + kwargs["use_cuda"] = True + settings = RayExecutor.create_settings( + timeout_s=timeout_s, + placement_group_timeout_s=placement_group_timeout_s) executor = RayExecutor(settings, use_gpu=use_gpu, num_workers=num_workers) executor.start() - executor.run(train_fn, **kwargs) + executor.run(train_fn, kwargs=kwargs) if __name__ == "__main__": diff --git a/release/.buildkite/build_pipeline.py b/release/.buildkite/build_pipeline.py index c2de1fe25302..e68f4bb8b2bf 100644 --- a/release/.buildkite/build_pipeline.py +++ b/release/.buildkite/build_pipeline.py @@ -20,7 +20,12 @@ class ReleaseTest: - def __init__(self, name: str, smoke_test: bool = False, retry: int = 0): + def __init__( + self, + name: str, + smoke_test: bool = False, + retry: int = 0, + ): self.name = name self.smoke_test = smoke_test self.retry = retry @@ -243,6 +248,19 @@ def __init__(self, ], } +HOROVOD_INSTALL_ENV_VARS = [ + "HOROVOD_WITH_GLOO", "HOROVOD_WITHOUT_MPI", "HOROVOD_WITHOUT_TENSORFLOW", + "HOROVOD_WITHOUT_MXNET", "HOROVOD_WITH_PYTORCH" +] + +HOROVOD_SETUP_COMMANDS = [ + "sudo apt update", "sudo apt -y install build-essential", + "pip install cmake" +] + [ + f"export {horovod_env_var}=1" + for horovod_env_var in HOROVOD_INSTALL_ENV_VARS +] + # This test suite holds "user" tests to test important user workflows # in a particular environment. # All workloads in this test suite should: @@ -251,6 +269,17 @@ def __init__(self, # 3. Use GPUs if applicable # 4. Have the `use_connect` flag set. USER_TESTS = { + "~/ray/release/horovod_tests/horovod_tests.yaml": [ + ConnectTest( + "horovod_user_test_latest", + setup_commands=HOROVOD_SETUP_COMMANDS, + requirements_file="release/horovod_tests/driver_requirements.txt"), + ConnectTest( + "horovod_user_test_master", + setup_commands=HOROVOD_SETUP_COMMANDS, + requirements_file="release/horovod_tests" + "/driver_requirements_master.txt") + ], "~/ray/release/train_tests/train_tests.yaml": [ ConnectTest( "train_tensorflow_mnist_test", @@ -260,7 +289,7 @@ def __init__(self, "train_torch_linear_test", requirements_file="release/train_tests" "/driver_requirements.txt") - ] + ], } SUITES = { @@ -484,22 +513,21 @@ def create_test_step( }] } - step_conf["commands"] = [ - "pip install -q -r release/requirements.txt", - "pip install -U boto3 botocore", - f"git clone -b {ray_test_branch} {ray_test_repo} ~/ray", cmd, - "sudo cp -rf /tmp/artifacts/* /tmp/ray_release_test_artifacts " - "|| true" - ] - if isinstance(test_name, ConnectTest): # Add driver side setup commands to the step. pip_requirements_command = [f"pip install -U -r " f"{test_name.requirements_file}"] if \ test_name.requirements_file else [] step_conf["commands"] = test_name.setup_commands \ - + pip_requirements_command \ - + step_conf["commands"] + + pip_requirements_command + + step_conf["commands"] += [ + "pip install -q -r release/requirements.txt", + "pip install -U boto3 botocore", + f"git clone -b {ray_test_branch} {ray_test_repo} ~/ray", cmd, + "sudo cp -rf /tmp/artifacts/* /tmp/ray_release_test_artifacts " + "|| true" + ] step_conf["label"] = ( f"{test_name} " diff --git a/release/horovod_tests/app_config.yaml b/release/horovod_tests/app_config.yaml index 15dd1051603e..6678b4beb922 100644 --- a/release/horovod_tests/app_config.yaml +++ b/release/horovod_tests/app_config.yaml @@ -14,7 +14,7 @@ post_build_cmds: - sudo rm -rf /home/ray/anaconda3/lib/python3.7/site-packages/numpy - pip3 install numpy || true - pip3 install -U {{ env["RAY_WHEELS"] | default("ray") }} - - pip3 install 'ray[rllib]' + - pip3 install 'ray[tune]' - pip3 install torch torchvision - - HOROVOD_WITH_GLOO=1 HOROVOD_WITHOUT_MPI=1 HOROVOD_WITHOUT_TENSORFLOW=1 HOROVOD_WITHOUT_MXNET=1 HOROVOD_WITH_PYTORCH=1 pip3 install -U git+https://github.com/horovod/horovod.git + - HOROVOD_WITH_GLOO=1 HOROVOD_WITHOUT_MPI=1 HOROVOD_WITHOUT_TENSORFLOW=1 HOROVOD_WITHOUT_MXNET=1 HOROVOD_WITH_PYTORCH=1 pip3 install -U horovod - {{ env["RAY_WHEELS_SANITY_CHECK"] | default("echo No Ray wheels sanity check") }} diff --git a/release/horovod_tests/app_config_master.yaml b/release/horovod_tests/app_config_master.yaml new file mode 100644 index 000000000000..c53c0e981fa9 --- /dev/null +++ b/release/horovod_tests/app_config_master.yaml @@ -0,0 +1,20 @@ +base_image: "anyscale/ray-ml:pinned-nightly-py37-gpu" +env_vars: {} +debian_packages: + - curl + +python: + pip_packages: + - pytest + - awscli + conda_packages: [] + +post_build_cmds: + - pip uninstall -y numpy ray || true + - sudo rm -rf /home/ray/anaconda3/lib/python3.7/site-packages/numpy + - pip3 install numpy || true + - pip3 install -U {{ env["RAY_WHEELS"] | default("ray") }} + - pip3 install 'ray[tune]' + - pip3 install torch torchvision + - HOROVOD_WITH_GLOO=1 HOROVOD_WITHOUT_MPI=1 HOROVOD_WITHOUT_TENSORFLOW=1 HOROVOD_WITHOUT_MXNET=1 HOROVOD_WITH_PYTORCH=1 pip3 install -U git+https://github.com/horovod/horovod.git + - {{ env["RAY_WHEELS_SANITY_CHECK"] | default("echo No Ray wheels sanity check") }} diff --git a/release/horovod_tests/base_driver_reqs.txt b/release/horovod_tests/base_driver_reqs.txt new file mode 100644 index 000000000000..057dffa317b2 --- /dev/null +++ b/release/horovod_tests/base_driver_reqs.txt @@ -0,0 +1,8 @@ +# Make sure the driver versions are the same as cluster versions. +# The cluster uses ray-ml Docker image. +# ray-ml Docker image installs dependencies from ray/python/requirements/ml/ directory. +# We constrain on these requirements file so that the same versions are installed. +-c ../../python/requirements/ml/requirements_dl.txt + +torch +torchvision \ No newline at end of file diff --git a/release/horovod_tests/compute_tpl.yaml b/release/horovod_tests/compute_tpl.yaml index 1d6d1686a0a7..3a5b4428d90b 100644 --- a/release/horovod_tests/compute_tpl.yaml +++ b/release/horovod_tests/compute_tpl.yaml @@ -10,8 +10,8 @@ head_node_type: worker_node_types: - name: worker_node instance_type: g3.8xlarge - min_workers: 3 max_workers: 3 + min_workers: 3 use_spot: false aws: diff --git a/release/horovod_tests/compute_tpl_autoscaling.yaml b/release/horovod_tests/compute_tpl_autoscaling.yaml new file mode 100644 index 000000000000..7f156c8756ab --- /dev/null +++ b/release/horovod_tests/compute_tpl_autoscaling.yaml @@ -0,0 +1,24 @@ +cloud_id: {{env["ANYSCALE_CLOUD_ID"]}} +region: us-west-2 + +max_workers: 3 + +head_node_type: + name: head_node + instance_type: g3.8xlarge + +worker_node_types: + - name: worker_node + instance_type: g3.8xlarge + max_workers: 3 + min_workers: 0 + use_spot: false + +aws: + TagSpecifications: + - ResourceType: "instance" + Tags: + - Key: anyscale-user + Value: '{{env["ANYSCALE_USER"]}}' + - Key: anyscale-expiration + Value: '{{env["EXPIRATION_1D"]}}' diff --git a/release/horovod_tests/driver_requirements.txt b/release/horovod_tests/driver_requirements.txt new file mode 100644 index 000000000000..8ce2c8a59a29 --- /dev/null +++ b/release/horovod_tests/driver_requirements.txt @@ -0,0 +1,3 @@ +-r ./base_driver_reqs.txt + +horovod diff --git a/release/horovod_tests/driver_requirements_master.txt b/release/horovod_tests/driver_requirements_master.txt new file mode 100644 index 000000000000..5fc9bfa194a5 --- /dev/null +++ b/release/horovod_tests/driver_requirements_master.txt @@ -0,0 +1,4 @@ +-r ./base_driver_reqs.txt + +# Horovod master. +git+https://github.com/horovod/horovod.git \ No newline at end of file diff --git a/release/horovod_tests/horovod_tests.yaml b/release/horovod_tests/horovod_tests.yaml index 0ddcd5b7bf12..9d3815d8315a 100644 --- a/release/horovod_tests/horovod_tests.yaml +++ b/release/horovod_tests/horovod_tests.yaml @@ -1,6 +1,6 @@ - name: horovod_test cluster: - app_config: app_config.yaml + app_config: app_config_master.yaml compute_template: compute_tpl.yaml run: @@ -12,3 +12,25 @@ smoke_test: run: timeout: 1800 + +- name: horovod_user_test_latest + cluster: + app_config: app_config.yaml + compute_template: compute_tpl_autoscaling.yaml + + run: + use_connect: True + autosuspend_mins: 10 + timeout: 1200 + script: python workloads/horovod_user_test.py + +- name: horovod_user_test_master + cluster: + app_config: app_config_master.yaml + compute_template: compute_tpl_autoscaling.yaml + + run: + use_connect: True + autosuspend_mins: 10 + timeout: 1200 + script: python workloads/horovod_user_test.py diff --git a/release/horovod_tests/workloads/horovod_test.py b/release/horovod_tests/workloads/horovod_tune_test.py similarity index 100% rename from release/horovod_tests/workloads/horovod_test.py rename to release/horovod_tests/workloads/horovod_tune_test.py diff --git a/release/horovod_tests/workloads/horovod_user_test.py b/release/horovod_tests/workloads/horovod_user_test.py new file mode 100644 index 000000000000..f1b53e350df2 --- /dev/null +++ b/release/horovod_tests/workloads/horovod_user_test.py @@ -0,0 +1,33 @@ +import json +import os +import time + +import ray +from ray.util.horovod.horovod_example import main + +if __name__ == "__main__": + start = time.time() + + addr = os.environ.get("RAY_ADDRESS") + job_name = os.environ.get("RAY_JOB_NAME", "horovod_user_test") + if addr is not None and addr.startswith("anyscale://"): + ray.init(address=addr, job_name=job_name) + else: + ray.init(address="auto") + + main( + num_workers=6, + use_gpu=True, + placement_group_timeout_s=900, + kwargs={"num_epochs": 20}) + + taken = time.time() - start + result = { + "time_taken": taken, + } + test_output_json = os.environ.get("TEST_OUTPUT_JSON", + "/tmp/horovod_user_test.json") + with open(test_output_json, "wt") as f: + json.dump(result, f) + + print("Test Successful!") From 3a52187da8b7bdcbb66036efb22becb961365ba4 Mon Sep 17 00:00:00 2001 From: Amog Kamsetty Date: Mon, 1 Nov 2021 18:29:48 -0700 Subject: [PATCH 11/15] [Release/Lightning] Add Ray lightning user test (#19812) * wip * wip * add ray lightning test * fix * update * merge and add * fix * fix * rename * autoscale * add tblib * gloo backend * typo * upgrade torch * latest and master --- python/ray/util/ray_lightning/BUILD | 2 +- .../ray/util/ray_lightning/simple_example.py | 33 +++++++++++++++++-- python/requirements_ml_docker.txt | 5 +++ release/.buildkite/build_pipeline.py | 10 ++++++ release/ray_lightning_tests/app_config.yaml | 20 +++++++++++ .../app_config_master.yaml | 20 +++++++++++ release/ray_lightning_tests/compute_tpl.yaml | 24 ++++++++++++++ .../driver_requirements.txt | 9 +++++ .../ray_lightning_tests.yaml | 22 +++++++++++++ .../workloads/ray_lightning_user_test.py | 29 ++++++++++++++++ 10 files changed, 170 insertions(+), 4 deletions(-) create mode 100644 release/ray_lightning_tests/app_config.yaml create mode 100644 release/ray_lightning_tests/app_config_master.yaml create mode 100644 release/ray_lightning_tests/compute_tpl.yaml create mode 100644 release/ray_lightning_tests/driver_requirements.txt create mode 100644 release/ray_lightning_tests/ray_lightning_tests.yaml create mode 100644 release/ray_lightning_tests/workloads/ray_lightning_user_test.py diff --git a/python/ray/util/ray_lightning/BUILD b/python/ray/util/ray_lightning/BUILD index 75b45d271927..f4aaac2a00c0 100644 --- a/python/ray/util/ray_lightning/BUILD +++ b/python/ray/util/ray_lightning/BUILD @@ -1,5 +1,5 @@ # -------------------------------------------------------------------- -# Tests from the python/ray/util/xgboost directory. +# Tests from the python/ray/util/ray_lightning directory. # Please keep these sorted alphabetically. # -------------------------------------------------------------------- py_test( diff --git a/python/ray/util/ray_lightning/simple_example.py b/python/ray/util/ray_lightning/simple_example.py index 9b8b728364c8..9064fb0d208a 100644 --- a/python/ray/util/ray_lightning/simple_example.py +++ b/python/ray/util/ray_lightning/simple_example.py @@ -1,3 +1,4 @@ +import argparse import os import torch from torch import nn @@ -38,15 +39,41 @@ def configure_optimizers(self): return optimizer -def main(): +def main(num_workers: int = 2, use_gpu: bool = False, max_steps: int = 10): dataset = MNIST( os.getcwd(), download=True, transform=transforms.ToTensor()) train, val = random_split(dataset, [55000, 5000]) autoencoder = LitAutoEncoder() - trainer = pl.Trainer(plugins=[RayPlugin(num_workers=2)], max_steps=10) + trainer = pl.Trainer( + plugins=[RayPlugin(num_workers=num_workers, use_gpu=use_gpu)], + max_steps=max_steps) trainer.fit(autoencoder, DataLoader(train), DataLoader(val)) if __name__ == "__main__": - main() + parser = argparse.ArgumentParser( + description="Ray Lightning Example", + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument( + "--num-workers", + type=int, + default=2, + help="Number of workers to use for training.") + parser.add_argument( + "--max-steps", + type=int, + default=10, + help="Maximum number of steps to run for training.") + parser.add_argument( + "--use-gpu", + action="store_true", + default=False, + help="Whether to enable GPU training.") + + args = parser.parse_args() + + main( + num_workers=args.num_workers, + max_steps=args.max_steps, + use_gpu=args.use_gpu) diff --git a/python/requirements_ml_docker.txt b/python/requirements_ml_docker.txt index f30397ede599..ede4f9b30ef0 100644 --- a/python/requirements_ml_docker.txt +++ b/python/requirements_ml_docker.txt @@ -1,4 +1,9 @@ ipython + +# Needed for Ray Client error message serialization/deserialization. +tblib + + # In TF >v2, GPU support is included in the base package. tensorflow==2.5.0 tensorflow-probability==0.13.0 diff --git a/release/.buildkite/build_pipeline.py b/release/.buildkite/build_pipeline.py index e68f4bb8b2bf..d3faa2ccc872 100644 --- a/release/.buildkite/build_pipeline.py +++ b/release/.buildkite/build_pipeline.py @@ -269,6 +269,16 @@ def __init__(self, # 3. Use GPUs if applicable # 4. Have the `use_connect` flag set. USER_TESTS = { + "~/ray/release/ray_lightning_tests/ray_lightning_tests.yaml": [ + ConnectTest( + "ray_lightning_user_test_latest", + requirements_file="release/ray_lightning_tests" + "/driver_requirements.txt"), + ConnectTest( + "ray_lightning_user_test_master", + requirements_file="release/ray_lightning_tests" + "/driver_requirements.txt") + ], "~/ray/release/horovod_tests/horovod_tests.yaml": [ ConnectTest( "horovod_user_test_latest", diff --git a/release/ray_lightning_tests/app_config.yaml b/release/ray_lightning_tests/app_config.yaml new file mode 100644 index 000000000000..e3935fe41be0 --- /dev/null +++ b/release/ray_lightning_tests/app_config.yaml @@ -0,0 +1,20 @@ +base_image: "anyscale/ray-ml:pinned-nightly-py37-gpu" +env_vars: + PL_TORCH_DISTRIBUTED_BACKEND: gloo + +debian_packages: + - curl + +python: + pip_packages: + # TODO(amogkam): Remove the tblib, torch, and torchvision installs once we use nightly image. + - tblib + - torch==1.9.0 + - torchvision==0.10.0 + - ray-lightning + conda_packages: [] + +post_build_cmds: + - pip uninstall -y ray || true + - pip3 install -U {{ env["RAY_WHEELS"] | default("ray") }} + - {{ env["RAY_WHEELS_SANITY_CHECK"] | default("echo No Ray wheels sanity check") }} diff --git a/release/ray_lightning_tests/app_config_master.yaml b/release/ray_lightning_tests/app_config_master.yaml new file mode 100644 index 000000000000..d99cb56e0fa6 --- /dev/null +++ b/release/ray_lightning_tests/app_config_master.yaml @@ -0,0 +1,20 @@ +base_image: "anyscale/ray-ml:pinned-nightly-py37-gpu" +env_vars: + PL_TORCH_DISTRIBUTED_BACKEND: gloo + +debian_packages: + - curl + +python: + pip_packages: + # TODO(amogkam): Remove the tblib, torch, and torchvision installs once we use nightly image. + - tblib + - torch==1.9.0 + - torchvision==0.10.0 + - git+https://github.com/ray-project/ray_lightning#ray_lightning + conda_packages: [] + +post_build_cmds: + - pip uninstall -y ray || true + - pip3 install -U {{ env["RAY_WHEELS"] | default("ray") }} + - {{ env["RAY_WHEELS_SANITY_CHECK"] | default("echo No Ray wheels sanity check") }} diff --git a/release/ray_lightning_tests/compute_tpl.yaml b/release/ray_lightning_tests/compute_tpl.yaml new file mode 100644 index 000000000000..7809c13e7761 --- /dev/null +++ b/release/ray_lightning_tests/compute_tpl.yaml @@ -0,0 +1,24 @@ +cloud_id: {{env["ANYSCALE_CLOUD_ID"]}} +region: us-west-2 + +max_workers: 3 + +head_node_type: + name: head_node + instance_type: g3.8xlarge + +worker_node_types: + - name: worker_node + instance_type: g3.8xlarge + min_workers: 0 + max_workers: 2 + use_spot: false + +aws: + TagSpecifications: + - ResourceType: "instance" + Tags: + - Key: anyscale-user + Value: '{{env["ANYSCALE_USER"]}}' + - Key: anyscale-expiration + Value: '{{env["EXPIRATION_1D"]}}' diff --git a/release/ray_lightning_tests/driver_requirements.txt b/release/ray_lightning_tests/driver_requirements.txt new file mode 100644 index 000000000000..ba19e088e192 --- /dev/null +++ b/release/ray_lightning_tests/driver_requirements.txt @@ -0,0 +1,9 @@ +# Make sure the driver versions are the same as cluster versions. +# The cluster uses ray-ml Docker image. +# ray-ml Docker image installs dependencies from ray/python/requirements/ml/ directory. +# We constrain on these requirements file so that the same versions are installed. +-c ../../python/requirements/ml/requirements_dl.txt + +torch +torchvision +pytorch-lightning \ No newline at end of file diff --git a/release/ray_lightning_tests/ray_lightning_tests.yaml b/release/ray_lightning_tests/ray_lightning_tests.yaml new file mode 100644 index 000000000000..5eab5a9605d7 --- /dev/null +++ b/release/ray_lightning_tests/ray_lightning_tests.yaml @@ -0,0 +1,22 @@ +- name: ray_lightning_user_test_latest + cluster: + app_config: app_config.yaml + compute_template: compute_tpl.yaml + + run: + use_connect: True + autosuspend_mins: 10 + timeout: 1200 + script: python workloads/ray_lightning_user_test.py + + +- name: ray_lightning_user_test_master + cluster: + app_config: app_config_master.yaml + compute_template: compute_tpl.yaml + + run: + use_connect: True + autosuspend_mins: 10 + timeout: 1200 + script: python workloads/ray_lightning_user_test.py \ No newline at end of file diff --git a/release/ray_lightning_tests/workloads/ray_lightning_user_test.py b/release/ray_lightning_tests/workloads/ray_lightning_user_test.py new file mode 100644 index 000000000000..211e4cd96209 --- /dev/null +++ b/release/ray_lightning_tests/workloads/ray_lightning_user_test.py @@ -0,0 +1,29 @@ +import json +import os +import time + +import ray +from ray.util.ray_lightning.simple_example import main + +if __name__ == "__main__": + start = time.time() + + addr = os.environ.get("RAY_ADDRESS") + job_name = os.environ.get("RAY_JOB_NAME", "horovod_user_test") + if addr is not None and addr.startswith("anyscale://"): + ray.init(address=addr, job_name=job_name) + else: + ray.init(address="auto") + + main(num_workers=6, use_gpu=True, max_steps=50) + + taken = time.time() - start + result = { + "time_taken": taken, + } + test_output_json = os.environ.get("TEST_OUTPUT_JSON", + "/tmp/ray_lightning_user_test.json") + with open(test_output_json, "wt") as f: + json.dump(result, f) + + print("Test Successful!") From c48d86e4695ead18c2cf46cd705836249fb69224 Mon Sep 17 00:00:00 2001 From: xwjiang2010 <87673679+xwjiang2010@users.noreply.github.com> Date: Mon, 1 Nov 2021 19:38:58 -0700 Subject: [PATCH 12/15] [CI] change git protocol to use https. (#19964) --- .buildkite/windows/install/reqs.txt | 4 ++-- doc/requirements-doc.txt | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/.buildkite/windows/install/reqs.txt b/.buildkite/windows/install/reqs.txt index d4bff5c46644..76a574868235 100644 --- a/.buildkite/windows/install/reqs.txt +++ b/.buildkite/windows/install/reqs.txt @@ -41,8 +41,8 @@ pytest-tornasync pytest-trio pytest-twisted werkzeug -git+git://github.com/ray-project/tune-sklearn@master#tune-sklearn -git+git://github.com/ray-project/xgboost_ray@master#egg=xgboost_ray +git+https://github.com/ray-project/tune-sklearn@master#tune-sklearn +git+https://github.com/ray-project/xgboost_ray@master#egg=xgboost_ray scikit-optimize tensorflow gym diff --git a/doc/requirements-doc.txt b/doc/requirements-doc.txt index 0471886e0159..829cd4c6d4c7 100644 --- a/doc/requirements-doc.txt +++ b/doc/requirements-doc.txt @@ -31,9 +31,9 @@ starlette tabulate uvicorn werkzeug -git+git://github.com/ray-project/tune-sklearn@master#tune-sklearn -git+git://github.com/ray-project/xgboost_ray@master#egg=xgboost_ray -git+git://github.com/ray-project/lightgbm_ray@main#lightgbm_ray +git+https://github.com/ray-project/tune-sklearn@master#tune-sklearn +git+https://github.com/ray-project/xgboost_ray@master#egg=xgboost_ray +git+https://github.com/ray-project/lightgbm_ray@main#lightgbm_ray git+https://github.com/ray-project/ray_lightning#ray_lightning scikit-optimize sphinx-sitemap==2.2.0 From a907168184ee233a19d4397c1185f3a1488b12d2 Mon Sep 17 00:00:00 2001 From: Yi Cheng <74173148+iycheng@users.noreply.github.com> Date: Mon, 1 Nov 2021 19:52:03 -0700 Subject: [PATCH 13/15] [core] Fix wrong local resource view in raylet (#19911) ## Why are these changes needed? When gcs broad cast node resource change, raylet will use that to update local node as well which will lead to local node instance and nodes_ inconsistent. 1. local node has used all some pg resource 2. gcs broadcast node resources 3. local node now have resources 4. scheduler picks local node 5. local node can't schedule the task 6. since there is only one type of job and local nodes hasn't finished any tasks so it'll go to step 4 ==> hangs ## Related issue number #19438 --- python/ray/tests/test_placement_group_3.py | 52 +++++++++++++++++++ .../gcs/gcs_server/gcs_resource_manager.cc | 2 +- src/ray/raylet/node_manager.cc | 40 +++++++------- 3 files changed, 75 insertions(+), 19 deletions(-) diff --git a/python/ray/tests/test_placement_group_3.py b/python/ray/tests/test_placement_group_3.py index cee2b819c1a6..57009bbd90f4 100644 --- a/python/ray/tests/test_placement_group_3.py +++ b/python/ray/tests/test_placement_group_3.py @@ -646,5 +646,57 @@ def check_bundle_leaks(): wait_for_condition(check_bundle_leaks) +def test_placement_group_local_resource_view(monkeypatch, ray_start_cluster): + """Please refer to https://github.com/ray-project/ray/pull/19911 + for more details. + """ + with monkeypatch.context() as m: + # Increase broadcasting interval so that node resource will arrive + # at raylet after local resource all being allocated. + m.setenv("RAY_raylet_report_resources_period_milliseconds", "2000") + m.setenv("RAY_grpc_based_resource_broadcast", "true") + cluster = ray_start_cluster + + cluster.add_node(num_cpus=16, object_store_memory=1e9) + cluster.wait_for_nodes() + cluster.add_node(num_cpus=16, num_gpus=1) + cluster.wait_for_nodes() + NUM_CPU_BUNDLES = 30 + + @ray.remote(num_cpus=1) + class Worker(object): + def __init__(self, i): + self.i = i + + def work(self): + time.sleep(0.1) + print("work ", self.i) + + @ray.remote(num_cpus=1, num_gpus=1) + class Trainer(object): + def __init__(self, i): + self.i = i + + def train(self): + time.sleep(0.2) + print("train ", self.i) + + ray.init(address="auto") + bundles = [{"CPU": 1, "GPU": 1}] + bundles += [{"CPU": 1} for _ in range(NUM_CPU_BUNDLES)] + pg = placement_group(bundles, strategy="PACK") + ray.get(pg.ready()) + + # Local resource will be allocated and here we are to ensure + # local view is consistent and node resouce updates are discarded + workers = [ + Worker.options(placement_group=pg).remote(i) + for i in range(NUM_CPU_BUNDLES) + ] + trainer = Trainer.options(placement_group=pg).remote(0) + ray.get([workers[i].work.remote() for i in range(NUM_CPU_BUNDLES)]) + ray.get(trainer.train.remote()) + + if __name__ == "__main__": sys.exit(pytest.main(["-sv", __file__])) diff --git a/src/ray/gcs/gcs_server/gcs_resource_manager.cc b/src/ray/gcs/gcs_server/gcs_resource_manager.cc index 9e36db6efafa..4266c635bf98 100644 --- a/src/ray/gcs/gcs_server/gcs_resource_manager.cc +++ b/src/ray/gcs/gcs_server/gcs_resource_manager.cc @@ -58,7 +58,7 @@ void GcsResourceManager::HandleUpdateResources( const rpc::UpdateResourcesRequest &request, rpc::UpdateResourcesReply *reply, rpc::SendReplyCallback send_reply_callback) { NodeID node_id = NodeID::FromBinary(request.node_id()); - RAY_LOG(INFO) << "Updating resources, node id = " << node_id; + RAY_LOG(DEBUG) << "Updating resources, node id = " << node_id; auto changed_resources = std::make_shared>(); for (const auto &entry : request.resources()) { changed_resources->emplace(entry.first, entry.second.resource_capacity()); diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index 260f259657de..f184bf6f38fc 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -891,6 +891,9 @@ void NodeManager::ResourceCreateUpdated(const NodeID &node_id, RAY_LOG(DEBUG) << "[ResourceCreateUpdated] received callback from node id " << node_id << " with created or updated resources: " << createUpdatedResources.ToString() << ". Updating resource map."; + if (node_id == self_node_id_) { + return; + } // Update local_available_resources_ and SchedulingResources for (const auto &resource_pair : createUpdatedResources.GetResourceMap()) { @@ -900,11 +903,7 @@ void NodeManager::ResourceCreateUpdated(const NodeID &node_id, new_resource_capacity); } RAY_LOG(DEBUG) << "[ResourceCreateUpdated] Updated cluster_resource_map."; - - if (node_id == self_node_id_) { - // The resource update is on the local node, check if we can reschedule tasks. - cluster_task_manager_->ScheduleAndDispatchTasks(); - } + cluster_task_manager_->ScheduleAndDispatchTasks(); } void NodeManager::ResourceDeleted(const NodeID &node_id, @@ -1474,39 +1473,44 @@ void NodeManager::HandleUpdateResourceUsage( rpc::SendReplyCallback send_reply_callback) { rpc::ResourceUsageBroadcastData resource_usage_batch; resource_usage_batch.ParseFromString(request.serialized_resource_usage_batch()); - - if (resource_usage_batch.seq_no() != next_resource_seq_no_) { + // When next_resource_seq_no_ == 0 it means it just started. + // TODO: Fetch a snapshot from gcs for lightweight resource broadcasting + if (next_resource_seq_no_ != 0 && + resource_usage_batch.seq_no() != next_resource_seq_no_) { + // TODO (Alex): Ideally we would be really robust, and potentially eagerly + // pull a full resource "snapshot" from gcs to make sure our state doesn't + // diverge from GCS. RAY_LOG(WARNING) << "Raylet may have missed a resource broadcast. This either means that GCS has " "restarted, the network is heavily congested and is dropping, reordering, or " "duplicating packets. Expected seq#: " << next_resource_seq_no_ << ", but got: " << resource_usage_batch.seq_no() << "."; - // TODO (Alex): Ideally we would be really robust, and potentially eagerly - // pull a full resource "snapshot" from gcs to make sure our state doesn't - // diverge from GCS. + if (resource_usage_batch.seq_no() < next_resource_seq_no_) { + RAY_LOG(WARNING) << "Discard the the resource update since local version is newer"; + return; + } } next_resource_seq_no_ = resource_usage_batch.seq_no() + 1; for (const auto &resource_change_or_data : resource_usage_batch.batch()) { if (resource_change_or_data.has_data()) { const auto &resource_usage = resource_change_or_data.data(); - const NodeID &node_id = NodeID::FromBinary(resource_usage.node_id()); - if (node_id == self_node_id_) { - // Skip messages from self. - continue; + auto node_id = NodeID::FromBinary(resource_usage.node_id()); + // Skip messages from self. + if (node_id != self_node_id_) { + UpdateResourceUsage(node_id, resource_usage); } - UpdateResourceUsage(node_id, resource_usage); } else if (resource_change_or_data.has_change()) { const auto &resource_notification = resource_change_or_data.change(); - auto id = NodeID::FromBinary(resource_notification.node_id()); + auto node_id = NodeID::FromBinary(resource_notification.node_id()); if (resource_notification.updated_resources_size() != 0) { ResourceSet resource_set( MapFromProtobuf(resource_notification.updated_resources())); - ResourceCreateUpdated(id, resource_set); + ResourceCreateUpdated(node_id, resource_set); } if (resource_notification.deleted_resources_size() != 0) { - ResourceDeleted(id, + ResourceDeleted(node_id, VectorFromProtobuf(resource_notification.deleted_resources())); } } From a33466e90595276ab3c3ad180d1a56398b136f21 Mon Sep 17 00:00:00 2001 From: Kai Yang Date: Tue, 2 Nov 2021 11:03:12 +0800 Subject: [PATCH 14/15] [Core] Fail inflight tasks on actor restarting (#19354) ## Why are these changes needed? If an actor failover is triggered, but the RPC connection between the caller and the crashed actor instance is not disconnected automatically, subsequent tasks to the new actor instance may not be executed. The root cause is that the sequence numbers of tasks sent to the new actor instance is not starting from 0. Details can be found in #14727. This PR fixes it by ensuring all inflight actor tasks fail immediately when actor failover is detected (via actor state notifications). ## Related issue number closes #14727 --- python/ray/tests/test_failure.py | 86 +++++- src/mock/ray/core_worker/task_manager.h | 2 +- src/ray/core_worker/task_manager.cc | 2 +- src/ray/core_worker/task_manager.h | 4 +- .../test/direct_actor_transport_test.cc | 70 ++++- .../test/direct_task_transport_test.cc | 2 +- .../transport/direct_actor_transport.cc | 256 +++++++++++------- .../transport/direct_actor_transport.h | 12 +- 8 files changed, 318 insertions(+), 116 deletions(-) diff --git a/python/ray/tests/test_failure.py b/python/ray/tests/test_failure.py index e118fe72b897..ad34bca748db 100644 --- a/python/ray/tests/test_failure.py +++ b/python/ray/tests/test_failure.py @@ -1,4 +1,5 @@ import os +import signal import sys import time @@ -7,8 +8,9 @@ import ray import ray._private.utils +import ray._private.gcs_utils as gcs_utils import ray.ray_constants as ray_constants -from ray.exceptions import RayTaskError +from ray.exceptions import RayTaskError, RayActorError, GetTimeoutError from ray._private.test_utils import (wait_for_condition, SignalActor, init_error_pubsub, get_error_message) @@ -587,6 +589,88 @@ def f(): assert errors[0].type == ray_constants.RESOURCE_DEADLOCK_ERROR +@pytest.mark.parametrize( + "ray_start_cluster_head", [{ + "num_cpus": 0, + "_system_config": { + "raylet_death_check_interval_milliseconds": 10 * 1000, + "num_heartbeats_timeout": 10, + "raylet_heartbeat_period_milliseconds": 100, + "timeout_ms_task_wait_for_death_info": 100, + } + }], + indirect=True) +def test_actor_failover_with_bad_network(ray_start_cluster_head): + # The test case is to cover the scenario that when an actor FO happens, + # the caller receives the actor ALIVE notification and connects to the new + # actor instance while there are still some tasks sent to the previous + # actor instance haven't returned. + # + # It's not easy to reproduce this scenario, so we set + # `raylet_death_check_interval_milliseconds` to a large value and add a + # never-return function for the actor to keep the RPC connection alive + # while killing the node to trigger actor failover. Later we send SIGKILL + # to kill the previous actor process to let the task fail. + # + # The expected behavior is that after the actor is alive again and the + # previous RPC connection is broken, tasks sent via the previous RPC + # connection should fail but tasks sent via the new RPC connection should + # succeed. + + cluster = ray_start_cluster_head + node = cluster.add_node(num_cpus=1) + + @ray.remote(max_restarts=1) + class Actor: + def getpid(self): + return os.getpid() + + def never_return(self): + while True: + time.sleep(1) + return 0 + + # The actor should be placed on the non-head node. + actor = Actor.remote() + pid = ray.get(actor.getpid.remote()) + + # Submit a never-return task (task 1) to the actor. The return + # object should be unready. + obj1 = actor.never_return.remote() + with pytest.raises(GetTimeoutError): + ray.get(obj1, timeout=1) + + # Kill the non-head node and start a new one. Now GCS should trigger actor + # FO. Since we changed the interval of worker checking death of Raylet, + # the actor process won't quit in a short time. + cluster.remove_node(node, allow_graceful=False) + cluster.add_node(num_cpus=1) + + # The removed node will be marked as dead by GCS after 1 second and task 1 + # will return with failure after that. + with pytest.raises(RayActorError): + ray.get(obj1, timeout=2) + + # Wait for the actor to be alive again in a new worker process. + def check_actor_restart(): + actors = list(ray.state.actors().values()) + assert len(actors) == 1 + print(actors) + return (actors[0]["State"] == gcs_utils.ActorTableData.ALIVE + and actors[0]["NumRestarts"] == 1) + + wait_for_condition(check_actor_restart) + + # Kill the previous actor process. + os.kill(pid, signal.SIGKILL) + + # Submit another task (task 2) to the actor. + obj2 = actor.getpid.remote() + + # We should be able to get the return value of task 2 without any issue + ray.get(obj2) + + if __name__ == "__main__": import pytest sys.exit(pytest.main(["-v", __file__])) diff --git a/src/mock/ray/core_worker/task_manager.h b/src/mock/ray/core_worker/task_manager.h index effea598da9d..cfc54a8b6a28 100644 --- a/src/mock/ray/core_worker/task_manager.h +++ b/src/mock/ray/core_worker/task_manager.h @@ -24,7 +24,7 @@ class MockTaskFinisherInterface : public TaskFinisherInterface { const rpc::Address &actor_addr), (override)); MOCK_METHOD(bool, PendingTaskFailed, - (const TaskID &task_id, rpc::ErrorType error_type, Status *status, + (const TaskID &task_id, rpc::ErrorType error_type, const Status *status, const std::shared_ptr &creation_task_exception, bool immediately_mark_object_fail), (override)); diff --git a/src/ray/core_worker/task_manager.cc b/src/ray/core_worker/task_manager.cc index 5d4b27e45055..12c753257b65 100644 --- a/src/ray/core_worker/task_manager.cc +++ b/src/ray/core_worker/task_manager.cc @@ -350,7 +350,7 @@ bool TaskManager::RetryTaskIfPossible(const TaskID &task_id) { } bool TaskManager::PendingTaskFailed( - const TaskID &task_id, rpc::ErrorType error_type, Status *status, + const TaskID &task_id, rpc::ErrorType error_type, const Status *status, const std::shared_ptr &creation_task_exception, bool immediately_mark_object_fail) { // Note that this might be the __ray_terminate__ task, so we don't log diff --git a/src/ray/core_worker/task_manager.h b/src/ray/core_worker/task_manager.h index c59d307fd973..f77a2c0957b6 100644 --- a/src/ray/core_worker/task_manager.h +++ b/src/ray/core_worker/task_manager.h @@ -34,7 +34,7 @@ class TaskFinisherInterface { virtual bool RetryTaskIfPossible(const TaskID &task_id) = 0; virtual bool PendingTaskFailed( - const TaskID &task_id, rpc::ErrorType error_type, Status *status, + const TaskID &task_id, rpc::ErrorType error_type, const Status *status, const std::shared_ptr &creation_task_exception = nullptr, bool immediately_mark_object_fail = true) = 0; @@ -146,7 +146,7 @@ class TaskManager : public TaskFinisherInterface, public TaskResubmissionInterfa /// result object as failed. /// \return Whether the task will be retried or not. bool PendingTaskFailed( - const TaskID &task_id, rpc::ErrorType error_type, Status *status = nullptr, + const TaskID &task_id, rpc::ErrorType error_type, const Status *status = nullptr, const std::shared_ptr &creation_task_exception = nullptr, bool immediately_mark_object_fail = true) override; diff --git a/src/ray/core_worker/test/direct_actor_transport_test.cc b/src/ray/core_worker/test/direct_actor_transport_test.cc index ea5a15b48dc6..69d8d46d9bd7 100644 --- a/src/ray/core_worker/test/direct_actor_transport_test.cc +++ b/src/ray/core_worker/test/direct_actor_transport_test.cc @@ -41,7 +41,7 @@ TaskSpecification CreateActorTaskHelper(ActorID actor_id, WorkerID caller_worker int64_t counter, TaskID caller_id = TaskID::Nil()) { TaskSpecification task; - task.GetMutableMessage().set_task_id(TaskID::Nil().Binary()); + task.GetMutableMessage().set_task_id(TaskID::ForFakeTask().Binary()); task.GetMutableMessage().set_caller_id(caller_id.Binary()); task.GetMutableMessage().set_type(TaskType::ACTOR_TASK); task.GetMutableMessage().mutable_caller_address()->set_worker_id( @@ -137,7 +137,7 @@ TEST_F(DirectActorSubmitterTest, TestSubmitTask) { ASSERT_TRUE(submitter_.SubmitTask(task).ok()); ASSERT_EQ(worker_client_->callbacks.size(), 2); - EXPECT_CALL(*task_finisher_, CompletePendingTask(TaskID::Nil(), _, _)) + EXPECT_CALL(*task_finisher_, CompletePendingTask(_, _, _)) .Times(worker_client_->callbacks.size()); EXPECT_CALL(*task_finisher_, PendingTaskFailed(_, _, _, _, _)).Times(0); while (!worker_client_->callbacks.empty()) { @@ -277,10 +277,10 @@ TEST_F(DirectActorSubmitterTest, TestActorDead) { } EXPECT_CALL(*task_finisher_, PendingTaskFailed(_, _, _, _, _)).Times(0); - submitter_.DisconnectActor(actor_id, 0, /*dead=*/false); + submitter_.DisconnectActor(actor_id, 1, /*dead=*/false); // Actor marked as dead. All queued tasks should get failed. EXPECT_CALL(*task_finisher_, PendingTaskFailed(task2.TaskId(), _, _, _, _)).Times(1); - submitter_.DisconnectActor(actor_id, 1, /*dead=*/true); + submitter_.DisconnectActor(actor_id, 2, /*dead=*/true); } TEST_F(DirectActorSubmitterTest, TestActorRestartNoRetry) { @@ -303,14 +303,16 @@ TEST_F(DirectActorSubmitterTest, TestActorRestartNoRetry) { ASSERT_TRUE(submitter_.SubmitTask(task2).ok()); ASSERT_TRUE(submitter_.SubmitTask(task3).ok()); - EXPECT_CALL(*task_finisher_, CompletePendingTask(task1.TaskId(), _, _)).Times(2); - EXPECT_CALL(*task_finisher_, PendingTaskFailed(task2.TaskId(), _, _, _, _)).Times(2); + EXPECT_CALL(*task_finisher_, CompletePendingTask(task1.TaskId(), _, _)).Times(1); + EXPECT_CALL(*task_finisher_, PendingTaskFailed(task2.TaskId(), _, _, _, _)).Times(1); + EXPECT_CALL(*task_finisher_, PendingTaskFailed(task3.TaskId(), _, _, _, _)).Times(1); + EXPECT_CALL(*task_finisher_, CompletePendingTask(task4.TaskId(), _, _)).Times(1); // First task finishes. Second task fails. ASSERT_TRUE(worker_client_->ReplyPushTask(Status::OK())); ASSERT_TRUE(worker_client_->ReplyPushTask(Status::IOError(""))); // Simulate the actor failing. - submitter_.DisconnectActor(actor_id, 0, /*dead=*/false); + submitter_.DisconnectActor(actor_id, 1, /*dead=*/false); // Third task fails after the actor is disconnected. It should not get // retried. ASSERT_TRUE(worker_client_->ReplyPushTask(Status::IOError(""))); @@ -346,17 +348,20 @@ TEST_F(DirectActorSubmitterTest, TestActorRestartRetry) { ASSERT_TRUE(submitter_.SubmitTask(task3).ok()); // All tasks will eventually finish. - EXPECT_CALL(*task_finisher_, CompletePendingTask(task1.TaskId(), _, _)).Times(4); + EXPECT_CALL(*task_finisher_, CompletePendingTask(_, _, _)).Times(4); // Tasks 2 and 3 will be retried. EXPECT_CALL(*task_finisher_, PendingTaskFailed(task2.TaskId(), _, _, _, _)) - .Times(2) + .Times(1) + .WillRepeatedly(Return(true)); + EXPECT_CALL(*task_finisher_, PendingTaskFailed(task3.TaskId(), _, _, _, _)) + .Times(1) .WillRepeatedly(Return(true)); // First task finishes. Second task fails. ASSERT_TRUE(worker_client_->ReplyPushTask(Status::OK())); ASSERT_TRUE(worker_client_->ReplyPushTask(Status::IOError(""))); // Simulate the actor failing. - submitter_.DisconnectActor(actor_id, 0, /*dead=*/false); + submitter_.DisconnectActor(actor_id, 1, /*dead=*/false); // Third task fails after the actor is disconnected. ASSERT_TRUE(worker_client_->ReplyPushTask(Status::IOError(""))); @@ -395,7 +400,7 @@ TEST_F(DirectActorSubmitterTest, TestActorRestartOutOfOrderRetry) { ASSERT_TRUE(submitter_.SubmitTask(task2).ok()); ASSERT_TRUE(submitter_.SubmitTask(task3).ok()); // All tasks will eventually finish. - EXPECT_CALL(*task_finisher_, CompletePendingTask(task1.TaskId(), _, _)).Times(3); + EXPECT_CALL(*task_finisher_, CompletePendingTask(_, _, _)).Times(3); // Tasks 2 will be retried EXPECT_CALL(*task_finisher_, PendingTaskFailed(task2.TaskId(), _, _, _, _)) @@ -406,7 +411,7 @@ TEST_F(DirectActorSubmitterTest, TestActorRestartOutOfOrderRetry) { ASSERT_TRUE(worker_client_->ReplyPushTask(Status::OK(), /*index=*/1)); // Simulate the actor failing. ASSERT_TRUE(worker_client_->ReplyPushTask(Status::IOError(""), /*index=*/0)); - submitter_.DisconnectActor(actor_id, 0, /*dead=*/false); + submitter_.DisconnectActor(actor_id, 1, /*dead=*/false); // Actor gets restarted. addr.set_port(1); @@ -493,6 +498,47 @@ TEST_F(DirectActorSubmitterTest, TestActorRestartOutOfOrderGcs) { ASSERT_TRUE(submitter_.SubmitTask(task).ok()); } +TEST_F(DirectActorSubmitterTest, TestActorRestartFailInflightTasks) { + rpc::Address addr; + auto worker_id = WorkerID::FromRandom(); + addr.set_worker_id(worker_id.Binary()); + ActorID actor_id = ActorID::Of(JobID::FromInt(0), TaskID::Nil(), 0); + submitter_.AddActorQueueIfNotExists(actor_id); + addr.set_port(0); + submitter_.ConnectActor(actor_id, addr, 0); + ASSERT_EQ(worker_client_->callbacks.size(), 0); + ASSERT_EQ(num_clients_connected_, 1); + + // Create 3 tasks for the actor. + auto task1 = CreateActorTaskHelper(actor_id, worker_id, 0); + auto task2 = CreateActorTaskHelper(actor_id, worker_id, 1); + auto task3 = CreateActorTaskHelper(actor_id, worker_id, 1); + // Submit a task. + ASSERT_TRUE(submitter_.SubmitTask(task1).ok()); + EXPECT_CALL(*task_finisher_, CompletePendingTask(task1.TaskId(), _, _)).Times(1); + ASSERT_TRUE(worker_client_->ReplyPushTask(Status::OK())); + + // Submit 2 tasks. + ASSERT_TRUE(submitter_.SubmitTask(task2).ok()); + ASSERT_TRUE(submitter_.SubmitTask(task3).ok()); + // Actor failed, but the task replies are delayed (or in some scenarios, lost). + // We should still be able to fail the inflight tasks. + EXPECT_CALL(*task_finisher_, PendingTaskFailed(task2.TaskId(), _, _, _, _)).Times(1); + EXPECT_CALL(*task_finisher_, PendingTaskFailed(task3.TaskId(), _, _, _, _)).Times(1); + submitter_.DisconnectActor(actor_id, 1, /*dead=*/false); + + // The task replies are now received. Since the tasks are already failed, they will not + // be marked as failed or finished again. + EXPECT_CALL(*task_finisher_, CompletePendingTask(task2.TaskId(), _, _)).Times(0); + EXPECT_CALL(*task_finisher_, PendingTaskFailed(task2.TaskId(), _, _, _, _)).Times(0); + EXPECT_CALL(*task_finisher_, CompletePendingTask(task3.TaskId(), _, _)).Times(0); + EXPECT_CALL(*task_finisher_, PendingTaskFailed(task3.TaskId(), _, _, _, _)).Times(0); + // Task 2 replied with OK. + ASSERT_TRUE(worker_client_->ReplyPushTask(Status::OK())); + // Task 3 replied with error. + ASSERT_TRUE(worker_client_->ReplyPushTask(Status::IOError(""))); +} + class MockDependencyWaiter : public DependencyWaiter { public: MOCK_METHOD2(Wait, void(const std::vector &dependencies, diff --git a/src/ray/core_worker/test/direct_task_transport_test.cc b/src/ray/core_worker/test/direct_task_transport_test.cc index b631b1d37217..2db2ab426fc5 100644 --- a/src/ray/core_worker/test/direct_task_transport_test.cc +++ b/src/ray/core_worker/test/direct_task_transport_test.cc @@ -110,7 +110,7 @@ class MockTaskFinisher : public TaskFinisherInterface { } bool PendingTaskFailed( - const TaskID &task_id, rpc::ErrorType error_type, Status *status, + const TaskID &task_id, rpc::ErrorType error_type, const Status *status, const std::shared_ptr &creation_task_exception = nullptr, bool immediately_mark_object_fail = true) override { num_tasks_failed++; diff --git a/src/ray/core_worker/transport/direct_actor_transport.cc b/src/ray/core_worker/transport/direct_actor_transport.cc index e04421f080d9..d085884729b8 100644 --- a/src/ray/core_worker/transport/direct_actor_transport.cc +++ b/src/ray/core_worker/transport/direct_actor_transport.cc @@ -134,121 +134,161 @@ void CoreWorkerDirectActorTaskSubmitter::DisconnectRpcClient(ClientQueue &queue) queue.pending_force_kill.reset(); } +void CoreWorkerDirectActorTaskSubmitter::FailInflightTasks( + const std::unordered_map> + &inflight_task_callbacks) { + // NOTE(kfstorm): We invoke the callbacks with a bad status to act like there's a + // network issue. We don't call `task_finisher_.PendingTaskFailed` directly because + // there's much more work to do in the callback. + auto status = Status::IOError("Fail all inflight tasks due to actor state change."); + rpc::PushTaskReply reply; + for (const auto &entry : inflight_task_callbacks) { + entry.second(status, reply); + } +} + void CoreWorkerDirectActorTaskSubmitter::ConnectActor(const ActorID &actor_id, const rpc::Address &address, int64_t num_restarts) { RAY_LOG(DEBUG) << "Connecting to actor " << actor_id << " at worker " << WorkerID::FromBinary(address.worker_id()); - absl::MutexLock lock(&mu_); - auto queue = client_queues_.find(actor_id); - RAY_CHECK(queue != client_queues_.end()); - if (num_restarts < queue->second.num_restarts) { - // This message is about an old version of the actor and the actor has - // already restarted since then. Skip the connection. - RAY_LOG(INFO) << "Skip actor connection that has already been restarted, actor_id=" - << actor_id; - return; - } + std::unordered_map> + inflight_task_callbacks; - if (queue->second.rpc_client && - queue->second.rpc_client->Addr().ip_address() == address.ip_address() && - queue->second.rpc_client->Addr().port() == address.port()) { - RAY_LOG(DEBUG) << "Skip actor that has already been connected, actor_id=" << actor_id; - return; - } + { + absl::MutexLock lock(&mu_); - if (queue->second.state == rpc::ActorTableData::DEAD) { - // This message is about an old version of the actor and the actor has - // already died since then. Skip the connection. - return; - } + auto queue = client_queues_.find(actor_id); + RAY_CHECK(queue != client_queues_.end()); + if (num_restarts < queue->second.num_restarts) { + // This message is about an old version of the actor and the actor has + // already restarted since then. Skip the connection. + RAY_LOG(INFO) << "Skip actor connection that has already been restarted, actor_id=" + << actor_id; + return; + } - queue->second.num_restarts = num_restarts; - if (queue->second.rpc_client) { - // Clear the client to the old version of the actor. - DisconnectRpcClient(queue->second); + if (queue->second.rpc_client && + queue->second.rpc_client->Addr().ip_address() == address.ip_address() && + queue->second.rpc_client->Addr().port() == address.port()) { + RAY_LOG(DEBUG) << "Skip actor that has already been connected, actor_id=" + << actor_id; + return; + } + + if (queue->second.state == rpc::ActorTableData::DEAD) { + // This message is about an old version of the actor and the actor has + // already died since then. Skip the connection. + return; + } + + queue->second.num_restarts = num_restarts; + if (queue->second.rpc_client) { + // Clear the client to the old version of the actor. + DisconnectRpcClient(queue->second); + inflight_task_callbacks = std::move(queue->second.inflight_task_callbacks); + queue->second.inflight_task_callbacks.clear(); + } + + queue->second.state = rpc::ActorTableData::ALIVE; + // Update the mapping so new RPCs go out with the right intended worker id. + queue->second.worker_id = address.worker_id(); + // Create a new connection to the actor. + queue->second.rpc_client = core_worker_client_pool_.GetOrConnect(address); + // This assumes that all replies from the previous incarnation + // of the actor have been received. This assumption should be OK + // because we fail all inflight tasks in `DisconnectRpcClient`. + RAY_LOG(DEBUG) << "Resetting caller starts at for actor " << actor_id << " from " + << queue->second.caller_starts_at << " to " + << queue->second.next_task_reply_position; + queue->second.caller_starts_at = queue->second.next_task_reply_position; + + RAY_LOG(INFO) << "Connecting to actor " << actor_id << " at worker " + << WorkerID::FromBinary(address.worker_id()); + ResendOutOfOrderTasks(actor_id); + SendPendingTasks(actor_id); } - queue->second.state = rpc::ActorTableData::ALIVE; - // Update the mapping so new RPCs go out with the right intended worker id. - queue->second.worker_id = address.worker_id(); - // Create a new connection to the actor. - queue->second.rpc_client = core_worker_client_pool_.GetOrConnect(address); - // TODO(swang): This assumes that all replies from the previous incarnation - // of the actor have been received. Fix this by setting an epoch for each - // actor task, so we can ignore completed tasks from old epochs. - RAY_LOG(DEBUG) << "Resetting caller starts at for actor " << actor_id << " from " - << queue->second.caller_starts_at << " to " - << queue->second.next_task_reply_position; - queue->second.caller_starts_at = queue->second.next_task_reply_position; - - RAY_LOG(INFO) << "Connecting to actor " << actor_id << " at worker " - << WorkerID::FromBinary(address.worker_id()); - ResendOutOfOrderTasks(actor_id); - SendPendingTasks(actor_id); + // NOTE(kfstorm): We need to make sure the lock is released before invoking callbacks. + FailInflightTasks(inflight_task_callbacks); } void CoreWorkerDirectActorTaskSubmitter::DisconnectActor( const ActorID &actor_id, int64_t num_restarts, bool dead, const std::shared_ptr &creation_task_exception) { RAY_LOG(DEBUG) << "Disconnecting from actor " << actor_id; - absl::MutexLock lock(&mu_); - auto queue = client_queues_.find(actor_id); - RAY_CHECK(queue != client_queues_.end()); - if (num_restarts <= queue->second.num_restarts && !dead) { - // This message is about an old version of the actor that has already been - // restarted successfully. Skip the message handling. - RAY_LOG(INFO) << "Skip actor disconnection that has already been restarted, actor_id=" - << actor_id; - return; - } - // The actor failed, so erase the client for now. Either the actor is - // permanently dead or the new client will be inserted once the actor is - // restarted. - DisconnectRpcClient(queue->second); - - if (dead) { - queue->second.state = rpc::ActorTableData::DEAD; - queue->second.creation_task_exception = creation_task_exception; - // If there are pending requests, treat the pending tasks as failed. - RAY_LOG(INFO) << "Failing pending tasks for actor " << actor_id - << " because the actor is already dead."; - auto &requests = queue->second.requests; - auto head = requests.begin(); - - auto status = Status::IOError("cancelling all pending tasks of dead actor"); - while (head != requests.end()) { - const auto &task_spec = head->second.first; - task_finisher_.MarkTaskCanceled(task_spec.TaskId()); - // No need to increment the number of completed tasks since the actor is - // dead. - RAY_UNUSED(!task_finisher_.PendingTaskFailed(task_spec.TaskId(), - rpc::ErrorType::ACTOR_DIED, &status, - creation_task_exception)); - head = requests.erase(head); + std::unordered_map> + inflight_task_callbacks; + + { + absl::MutexLock lock(&mu_); + auto queue = client_queues_.find(actor_id); + RAY_CHECK(queue != client_queues_.end()); + if (!dead) { + RAY_CHECK(num_restarts > 0); + } + if (num_restarts <= queue->second.num_restarts && !dead) { + // This message is about an old version of the actor that has already been + // restarted successfully. Skip the message handling. + RAY_LOG(INFO) + << "Skip actor disconnection that has already been restarted, actor_id=" + << actor_id; + return; } - auto &wait_for_death_info_tasks = queue->second.wait_for_death_info_tasks; + // The actor failed, so erase the client for now. Either the actor is + // permanently dead or the new client will be inserted once the actor is + // restarted. + DisconnectRpcClient(queue->second); + inflight_task_callbacks = std::move(queue->second.inflight_task_callbacks); + queue->second.inflight_task_callbacks.clear(); + + if (dead) { + queue->second.state = rpc::ActorTableData::DEAD; + queue->second.creation_task_exception = creation_task_exception; + // If there are pending requests, treat the pending tasks as failed. + RAY_LOG(INFO) << "Failing pending tasks for actor " << actor_id + << " because the actor is already dead."; + auto &requests = queue->second.requests; + auto head = requests.begin(); + + auto status = Status::IOError("cancelling all pending tasks of dead actor"); + while (head != requests.end()) { + const auto &task_spec = head->second.first; + task_finisher_.MarkTaskCanceled(task_spec.TaskId()); + // No need to increment the number of completed tasks since the actor is + // dead. + RAY_UNUSED(!task_finisher_.PendingTaskFailed(task_spec.TaskId(), + rpc::ErrorType::ACTOR_DIED, &status, + creation_task_exception)); + head = requests.erase(head); + } - RAY_LOG(INFO) << "Failing tasks waiting for death info, size=" - << wait_for_death_info_tasks.size() << ", actor_id=" << actor_id; - for (auto &net_err_task : wait_for_death_info_tasks) { - RAY_UNUSED(task_finisher_.MarkPendingTaskFailed( - net_err_task.second, rpc::ErrorType::ACTOR_DIED, creation_task_exception)); - } + auto &wait_for_death_info_tasks = queue->second.wait_for_death_info_tasks; - // No need to clean up tasks that have been sent and are waiting for - // replies. They will be treated as failed once the connection dies. - // We retain the sequencing information so that we can properly fail - // any tasks submitted after the actor death. - } else if (queue->second.state != rpc::ActorTableData::DEAD) { - // Only update the actor's state if it is not permanently dead. The actor - // will eventually get restarted or marked as permanently dead. - queue->second.state = rpc::ActorTableData::RESTARTING; - queue->second.num_restarts = num_restarts; + RAY_LOG(INFO) << "Failing tasks waiting for death info, size=" + << wait_for_death_info_tasks.size() << ", actor_id=" << actor_id; + for (auto &net_err_task : wait_for_death_info_tasks) { + RAY_UNUSED(task_finisher_.MarkPendingTaskFailed( + net_err_task.second, rpc::ErrorType::ACTOR_DIED, creation_task_exception)); + } + + // No need to clean up tasks that have been sent and are waiting for + // replies. They will be treated as failed once the connection dies. + // We retain the sequencing information so that we can properly fail + // any tasks submitted after the actor death. + } else if (queue->second.state != rpc::ActorTableData::DEAD) { + // Only update the actor's state if it is not permanently dead. The actor + // will eventually get restarted or marked as permanently dead. + queue->second.state = rpc::ActorTableData::RESTARTING; + queue->second.num_restarts = num_restarts; + } } + + // NOTE(kfstorm): We need to make sure the lock is released before invoking callbacks. + FailInflightTasks(inflight_task_callbacks); } void CoreWorkerDirectActorTaskSubmitter::CheckTimeoutTasks() { @@ -319,7 +359,7 @@ void CoreWorkerDirectActorTaskSubmitter::ResendOutOfOrderTasks(const ActorID &ac client_queue.out_of_order_completed_tasks.clear(); } -void CoreWorkerDirectActorTaskSubmitter::PushActorTask(const ClientQueue &queue, +void CoreWorkerDirectActorTaskSubmitter::PushActorTask(ClientQueue &queue, const TaskSpecification &task_spec, bool skip_queue) { auto request = std::make_unique(); @@ -349,10 +389,9 @@ void CoreWorkerDirectActorTaskSubmitter::PushActorTask(const ClientQueue &queue, } rpc::Address addr(queue.rpc_client->Addr()); - queue.rpc_client->PushActorTask( - std::move(request), skip_queue, + rpc::ClientCallback reply_callback = [this, addr, task_id, actor_id, actor_counter, task_spec, task_skipped]( - Status status, const rpc::PushTaskReply &reply) { + const Status &status, const rpc::PushTaskReply &reply) { bool increment_completed_tasks = true; if (task_skipped) { @@ -420,7 +459,30 @@ void CoreWorkerDirectActorTaskSubmitter::PushActorTask(const ClientQueue &queue, << " and size of out_of_order_tasks set is " << queue.out_of_order_completed_tasks.size(); } - }); + }; + + queue.inflight_task_callbacks.emplace(task_id, std::move(reply_callback)); + rpc::ClientCallback wrapped_callback = + [this, task_id, actor_id](const Status &status, const rpc::PushTaskReply &reply) { + rpc::ClientCallback reply_callback; + { + absl::MutexLock lock(&mu_); + auto it = client_queues_.find(actor_id); + RAY_CHECK(it != client_queues_.end()); + auto &queue = it->second; + auto callback_it = queue.inflight_task_callbacks.find(task_id); + if (callback_it == queue.inflight_task_callbacks.end()) { + RAY_LOG(DEBUG) << "The task " << task_id + << " has already been marked as failed. Ingore the reply."; + return; + } + reply_callback = std::move(callback_it->second); + queue.inflight_task_callbacks.erase(callback_it); + } + reply_callback(status, reply); + }; + + queue.rpc_client->PushActorTask(std::move(request), skip_queue, wrapped_callback); } bool CoreWorkerDirectActorTaskSubmitter::IsActorAlive(const ActorID &actor_id) const { diff --git a/src/ray/core_worker/transport/direct_actor_transport.h b/src/ray/core_worker/transport/direct_actor_transport.h index 42a048ac2f16..162e9dd52b49 100644 --- a/src/ray/core_worker/transport/direct_actor_transport.h +++ b/src/ray/core_worker/transport/direct_actor_transport.h @@ -223,6 +223,11 @@ class CoreWorkerDirectActorTaskSubmitter /// A force-kill request that should be sent to the actor once an RPC /// client to the actor is available. absl::optional pending_force_kill; + + /// Stores all callbacks of inflight tasks. Note that this doesn't include tasks + /// without replies. + std::unordered_map> + inflight_task_callbacks; }; /// Push a task to a remote actor via the given client. @@ -234,7 +239,7 @@ class CoreWorkerDirectActorTaskSubmitter /// \param[in] skip_queue Whether to skip the task queue. This will send the /// task for execution immediately. /// \return Void. - void PushActorTask(const ClientQueue &queue, const TaskSpecification &task_spec, + void PushActorTask(ClientQueue &queue, const TaskSpecification &task_spec, bool skip_queue) EXCLUSIVE_LOCKS_REQUIRED(mu_); /// Send all pending tasks for an actor. @@ -253,6 +258,11 @@ class CoreWorkerDirectActorTaskSubmitter /// Disconnect the RPC client for an actor. void DisconnectRpcClient(ClientQueue &queue) EXCLUSIVE_LOCKS_REQUIRED(mu_); + /// Fail all in-flight tasks. + void FailInflightTasks( + const std::unordered_map> + &inflight_task_callbacks) LOCKS_EXCLUDED(mu_); + /// Whether the specified actor is alive. /// /// \param[in] actor_id The actor ID. From da6894848d538cdc26a1ca82bf91436a3cde1828 Mon Sep 17 00:00:00 2001 From: Qing Wang Date: Tue, 2 Nov 2021 11:05:40 +0800 Subject: [PATCH 15/15] Support Java namespace APIs (#19468) ## Why are these changes needed? ## Related issue number #16474 --- .../java/io/ray/runtime/RayNativeRuntime.java | 3 +- .../java/io/ray/runtime/config/RayConfig.java | 5 ++ .../src/main/resources/ray.default.conf | 4 ++ .../java/io/ray/test/MultiDriverTest.java | 2 - .../main/java/io/ray/test/NamespaceTest.java | 71 +++++++++++++++++++ 5 files changed, 82 insertions(+), 3 deletions(-) create mode 100644 java/test/src/main/java/io/ray/test/NamespaceTest.java diff --git a/java/runtime/src/main/java/io/ray/runtime/RayNativeRuntime.java b/java/runtime/src/main/java/io/ray/runtime/RayNativeRuntime.java index 1d8d60e1048f..af90153a506b 100644 --- a/java/runtime/src/main/java/io/ray/runtime/RayNativeRuntime.java +++ b/java/runtime/src/main/java/io/ray/runtime/RayNativeRuntime.java @@ -106,7 +106,8 @@ public void start() { JobConfig.newBuilder() .setNumJavaWorkersPerProcess(rayConfig.numWorkersPerProcess) .addAllJvmOptions(rayConfig.jvmOptionsForJavaWorker) - .addAllCodeSearchPath(rayConfig.codeSearchPath); + .addAllCodeSearchPath(rayConfig.codeSearchPath) + .setRayNamespace(rayConfig.namespace); RuntimeEnv.Builder runtimeEnvBuilder = RuntimeEnv.newBuilder(); if (!rayConfig.workerEnv.isEmpty()) { // TODO(SongGuyang): Suppport complete runtime env interface for users. diff --git a/java/runtime/src/main/java/io/ray/runtime/config/RayConfig.java b/java/runtime/src/main/java/io/ray/runtime/config/RayConfig.java index 511489b0d2f9..d651b08f6607 100644 --- a/java/runtime/src/main/java/io/ray/runtime/config/RayConfig.java +++ b/java/runtime/src/main/java/io/ray/runtime/config/RayConfig.java @@ -69,6 +69,8 @@ public LoggerConf(String loggerName, String fileName, String pattern) { public final int numWorkersPerProcess; + public final String namespace; + public final List jvmOptionsForJavaWorker; public final Map workerEnv; @@ -118,6 +120,9 @@ public RayConfig(Config config) { this.jobId = JobId.NIL; } + // Namespace of this job. + namespace = config.getString("ray.job.namespace"); + // jvm options for java workers of this job. jvmOptionsForJavaWorker = config.getStringList("ray.job.jvm-options"); diff --git a/java/runtime/src/main/resources/ray.default.conf b/java/runtime/src/main/resources/ray.default.conf index 2c865530e5ce..0d94d72e30b3 100644 --- a/java/runtime/src/main/resources/ray.default.conf +++ b/java/runtime/src/main/resources/ray.default.conf @@ -36,6 +36,10 @@ ray { // key1 : "value1" // key2 : "value2" } + /// The namespace of this job. It's used for isolation between jobs. + /// Jobs in different namespaces cannot access each other. + /// If it's not specified, a randomized value will be used instead. + namespace: "" } // Configurations about raylet diff --git a/java/test/src/main/java/io/ray/test/MultiDriverTest.java b/java/test/src/main/java/io/ray/test/MultiDriverTest.java index cd2057279467..99425989e323 100644 --- a/java/test/src/main/java/io/ray/test/MultiDriverTest.java +++ b/java/test/src/main/java/io/ray/test/MultiDriverTest.java @@ -3,7 +3,6 @@ import io.ray.api.ActorHandle; import io.ray.api.ObjectRef; import io.ray.api.Ray; -import io.ray.runtime.config.RayConfig; import io.ray.runtime.util.SystemUtil; import java.io.BufferedReader; import java.io.IOException; @@ -102,7 +101,6 @@ public void testMultiDrivers() throws InterruptedException, IOException { } private Process startDriver() throws IOException { - RayConfig rayConfig = TestUtils.getRuntime().getRayConfig(); ProcessBuilder builder = TestUtils.buildDriver(MultiDriverTest.class, null); builder.redirectError(Redirect.INHERIT); return builder.start(); diff --git a/java/test/src/main/java/io/ray/test/NamespaceTest.java b/java/test/src/main/java/io/ray/test/NamespaceTest.java new file mode 100644 index 000000000000..1aa0ca6f5e52 --- /dev/null +++ b/java/test/src/main/java/io/ray/test/NamespaceTest.java @@ -0,0 +1,71 @@ +package io.ray.test; + +import io.ray.api.ActorHandle; +import io.ray.api.Ray; +import java.io.IOException; +import java.util.NoSuchElementException; +import java.util.concurrent.TimeUnit; +import org.testng.Assert; +import org.testng.annotations.Test; + +@Test(groups = "cluster") +public class NamespaceTest { + + private static class A { + public String hello() { + return "hello"; + } + } + + /// This case tests that actor cannot be accessed in different namespaces. + public void testIsolationBetweenNamespaces() throws IOException, InterruptedException { + System.setProperty("ray.job.namespace", "test2"); + testIsolation( + () -> + Assert.assertThrows( + NoSuchElementException.class, + () -> { + Ray.getGlobalActor("a").get(); + })); + } + + /// This case tests that actor can be accessed between different jobs but in the same namespace. + public void testIsolationInTheSameNamespaces() throws IOException, InterruptedException { + System.setProperty("ray.job.namespace", "test1"); + testIsolation( + () -> { + ActorHandle a = (ActorHandle) Ray.getGlobalActor("a").get(); + Assert.assertEquals("hello", a.task(A::hello).remote().get()); + }); + } + + public static void main(String[] args) throws IOException, InterruptedException { + System.setProperty("ray.job.namespace", "test1"); + Ray.init(); + ActorHandle a = Ray.actor(A::new).setGlobalName("a").remote(); + Assert.assertEquals("hello", a.task(A::hello).remote().get()); + /// Because we don't support long running job yet, so sleep to don't destroy + /// it for a while. Otherwise the actor created in this job will be destroyed + /// as well. + TimeUnit.SECONDS.sleep(10); + Ray.shutdown(); + } + + private void testIsolation(Runnable runnable) throws IOException, InterruptedException { + Process driver = null; + try { + Ray.init(); + ProcessBuilder builder = TestUtils.buildDriver(NamespaceTest.class, null); + builder.redirectError(ProcessBuilder.Redirect.INHERIT); + driver = builder.start(); + // Wait for driver to start. + TimeUnit.SECONDS.sleep(3); + runnable.run(); + } finally { + if (driver != null) { + driver.waitFor(1, TimeUnit.SECONDS); + } + Ray.shutdown(); + } + } +}