diff --git a/BUILD.bazel b/BUILD.bazel index 61484ba82f6b..e2cbdd64bf51 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -495,6 +495,7 @@ flatbuffer_py_library( "ConfigTableData.py", "CustomSerializerData.py", "DriverTableData.py", + "EntryType.py", "ErrorTableData.py", "ErrorType.py", "FunctionTableData.py", diff --git a/README.rst b/README.rst index 203c69ffc838..06dd8115fdf3 100644 --- a/README.rst +++ b/README.rst @@ -6,12 +6,12 @@ .. image:: https://readthedocs.org/projects/ray/badge/?version=latest :target: http://ray.readthedocs.io/en/latest/?badge=latest -.. image:: https://img.shields.io/badge/pypi-0.6.6-blue.svg +.. image:: https://img.shields.io/badge/pypi-0.7.0-blue.svg :target: https://pypi.org/project/ray/ | -**Ray is a flexible, high-performance distributed execution framework.** +**Ray is a fast and simple framework for building and running distributed applications.** Ray is easy to install: ``pip install ray`` diff --git a/bazel/ray.bzl b/bazel/ray.bzl index 4ba637f3cdd4..750b90a21aec 100644 --- a/bazel/ray.bzl +++ b/bazel/ray.bzl @@ -25,8 +25,10 @@ def flatbuffer_java_library(name, srcs, outs, out_prefix, includes = [], include ) def define_java_module(name, additional_srcs = [], additional_resources = [], define_test_lib = False, test_deps = [], **kwargs): + lib_name = "org_ray_ray_" + name + pom_file_targets = [lib_name] native.java_library( - name = "org_ray_ray_" + name, + name = lib_name, srcs = additional_srcs + native.glob([name + "/src/main/java/**/*.java"]), resources = native.glob([name + "/src/main/resources/**"]) + additional_resources, **kwargs @@ -40,8 +42,10 @@ def define_java_module(name, additional_srcs = [], additional_resources = [], de tags = ["checkstyle"], ) if define_test_lib: + test_lib_name = "org_ray_ray_" + name + "_test" + pom_file_targets.append(test_lib_name) native.java_library( - name = "org_ray_ray_" + name + "_test", + name = test_lib_name, srcs = native.glob([name + "/src/test/java/**/*.java"]), deps = test_deps, ) @@ -53,12 +57,11 @@ def define_java_module(name, additional_srcs = [], additional_resources = [], de size = "small", tags = ["checkstyle"], ) - -def gen_java_pom_file(name): pom_file( name = "org_ray_ray_" + name + "_pom", - targets = [ - ":org_ray_ray_" + name, - ], + targets = pom_file_targets, template_file = name + "/pom_template.xml", + substitutions = { + "{auto_gen_header}": "", + }, ) diff --git a/bazel/ray_deps_setup.bzl b/bazel/ray_deps_setup.bzl index 6094ed3c9303..dafa72b773fe 100644 --- a/bazel/ray_deps_setup.bzl +++ b/bazel/ray_deps_setup.bzl @@ -80,18 +80,18 @@ def ray_deps_setup(): http_archive( name = "io_opencensus_cpp", - strip_prefix = "opencensus-cpp-0.3.0", - urls = ["https://github.com/census-instrumentation/opencensus-cpp/archive/v0.3.0.zip"], + strip_prefix = "opencensus-cpp-3aa11f20dd610cb8d2f7c62e58d1e69196aadf11", + urls = ["https://github.com/census-instrumentation/opencensus-cpp/archive/3aa11f20dd610cb8d2f7c62e58d1e69196aadf11.zip"], ) # OpenCensus depends on Abseil so we have to explicitly pull it in. # This is how diamond dependencies are prevented. 
git_repository( name = "com_google_absl", - commit = "88a152ae747c3c42dc9167d46c590929b048d436", + commit = "5b65c4af5107176555b23a638e5947686410ac1f", remote = "https://github.com/abseil/abseil-cpp.git", ) - + # OpenCensus depends on jupp0r/prometheus-cpp http_archive( name = "com_github_jupp0r_prometheus_cpp", diff --git a/ci/jenkins_tests/perf_integration_tests/run_perf_integration.sh b/ci/jenkins_tests/perf_integration_tests/run_perf_integration.sh index 927f8bf5e83d..f723d5122981 100755 --- a/ci/jenkins_tests/perf_integration_tests/run_perf_integration.sh +++ b/ci/jenkins_tests/perf_integration_tests/run_perf_integration.sh @@ -9,7 +9,7 @@ pushd "$ROOT_DIR" python -m pip install pytest-benchmark -pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.7.0.dev3-cp27-cp27mu-manylinux1_x86_64.whl +pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-cp27-cp27mu-manylinux1_x86_64.whl python -m pytest --benchmark-autosave --benchmark-min-rounds=10 --benchmark-columns="min, max, mean" $ROOT_DIR/../../../python/ray/tests/perf_integration_tests/test_perf_integration.py pushd $ROOT_DIR/../../../python diff --git a/ci/jenkins_tests/run_rllib_tests.sh b/ci/jenkins_tests/run_rllib_tests.sh index efe30a0a7780..13acff28d39c 100644 --- a/ci/jenkins_tests/run_rllib_tests.sh +++ b/ci/jenkins_tests/run_rllib_tests.sh @@ -289,6 +289,9 @@ docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ /ray/ci/suppress_output python /ray/python/ray/rllib/tests/test_local.py +docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ + /ray/ci/suppress_output python /ray/python/ray/rllib/tests/test_dependency.py + docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ /ray/ci/suppress_output python /ray/python/ray/rllib/tests/test_legacy.py @@ -365,6 +368,9 @@ docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ /ray/ci/suppress_output python /ray/python/ray/rllib/examples/multiagent_cartpole.py --num-iters=2 +docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ + /ray/ci/suppress_output python /ray/python/ray/rllib/examples/multiagent_cartpole.py --num-iters=2 --simple + docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ /ray/ci/suppress_output python /ray/python/ray/rllib/examples/multiagent_two_trainers.py --num-iters=2 diff --git a/ci/long_running_tests/config.yaml b/ci/long_running_tests/config.yaml index cbc7feb435af..b9667ae648bb 100644 --- a/ci/long_running_tests/config.yaml +++ b/ci/long_running_tests/config.yaml @@ -49,6 +49,7 @@ setup_commands: # - sudo apt-get update # - sudo apt-get install -y build-essential curl unzip # - git clone https://github.com/ray-project/ray || true + # - ray/ci/travis/install-bazel.sh # - cd ray/python; git checkout master; git pull; pip install -e . --verbose # Install nightly Ray wheels. - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/<>/ray-<>-cp36-cp36m-manylinux1_x86_64.whl diff --git a/ci/stress_tests/application_cluster_template.yaml b/ci/stress_tests/application_cluster_template.yaml index e8fc0efa2bcd..541419da55af 100644 --- a/ci/stress_tests/application_cluster_template.yaml +++ b/ci/stress_tests/application_cluster_template.yaml @@ -90,8 +90,8 @@ file_mounts: { # List of shell commands to run to set up nodes. 
setup_commands: - echo 'export PATH="$HOME/anaconda3/envs/tensorflow_<<>>/bin:$PATH"' >> ~/.bashrc - - ray || wget https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.7.0.dev2-<<>>-manylinux1_x86_64.whl - - rllib || pip install -U ray-0.7.0.dev2-<<>>-manylinux1_x86_64.whl[rllib] + - ray || wget https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-<<>>-manylinux1_x86_64.whl + - rllib || pip install -U ray-0.8.0.dev0-<<>>-manylinux1_x86_64.whl[rllib] - pip install tensorflow-gpu==1.12.0 - echo "sudo halt" | at now + 60 minutes # Consider uncommenting these if you also want to run apt-get commands during setup diff --git a/ci/stress_tests/stress_testing_config.yaml b/ci/stress_tests/stress_testing_config.yaml index 3ea6f7f717a3..f71ae8f2dc18 100644 --- a/ci/stress_tests/stress_testing_config.yaml +++ b/ci/stress_tests/stress_testing_config.yaml @@ -98,9 +98,10 @@ setup_commands: - echo 'export PATH="$HOME/anaconda3/bin:$PATH"' >> ~/.bashrc # # Build Ray. # - git clone https://github.com/ray-project/ray || true + # - ray/ci/travis/install-bazel.sh - pip install boto3==1.4.8 cython==0.29.0 # - cd ray/python; git checkout master; git pull; pip install -e . --verbose - - pip install https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.7.0.dev2-cp36-cp36m-manylinux1_x86_64.whl + - pip install https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-cp36-cp36m-manylinux1_x86_64.whl - echo "sudo halt" | at now + 60 minutes # Custom commands that will be run on the head node after common setup. diff --git a/ci/stress_tests/test_many_tasks_and_transfers.py b/ci/stress_tests/test_many_tasks_and_transfers.py index c4c4825b1bb9..985e05c28a02 100644 --- a/ci/stress_tests/test_many_tasks_and_transfers.py +++ b/ci/stress_tests/test_many_tasks_and_transfers.py @@ -24,10 +24,10 @@ # Wait until the expected number of nodes have joined the cluster. while True: - if len(ray.global_state.client_table()) >= num_remote_nodes + 1: + if len(ray.nodes()) >= num_remote_nodes + 1: break logger.info("Nodes have all joined. There are %s resources.", - ray.global_state.cluster_resources()) + ray.cluster_resources()) # Require 1 GPU to force the tasks to be on remote machines. diff --git a/dev/RELEASE_PROCESS.rst b/dev/RELEASE_PROCESS.rst index 88a243c28bb6..62862506e1ed 100644 --- a/dev/RELEASE_PROCESS.rst +++ b/dev/RELEASE_PROCESS.rst @@ -41,12 +41,12 @@ This document describes the process for creating new releases. 6. **Download all the wheels:** Now the release is ready to begin final testing. The wheels are automatically uploaded to S3, even on the release - branch. The wheels can ``pip install``ed from the following URLs: + branch. To test, ``pip install`` from the following URLs: .. code-block:: bash export RAY_HASH=... # e.g., 618147f57fb40368448da3b2fb4fd213828fa12b - export RAY_VERSION=... # e.g., 0.6.6 + export RAY_VERSION=... # e.g., 0.7.0 pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/$RAY_HASH/ray-$RAY_VERSION-cp27-cp27mu-manylinux1_x86_64.whl pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/$RAY_HASH/ray-$RAY_VERSION-cp35-cp35m-manylinux1_x86_64.whl pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/$RAY_HASH/ray-$RAY_VERSION-cp36-cp36m-manylinux1_x86_64.whl @@ -120,10 +120,28 @@ This document describes the process for creating new releases. git pull origin master --tags git log $(git describe --tags --abbrev=0)..HEAD --pretty=format:"%s" | sort +11. 
**Bump version on Ray master branch:** Create a pull request to increment the + version of the master branch. The format of the new version is as follows: + + New minor release (e.g., 0.7.0): Increment the minor version and append ``.dev0`` to + the version. For example, if the version of the new release is 0.7.0, the master + branch needs to be updated to 0.8.0.dev0. `Example PR for minor release` + + New micro release (e.g., 0.7.1): Increment the ``dev`` number, such that the number + after ``dev`` equals the micro version. For example, if the version of the new + release is 0.7.1, the master branch needs to be updated to 0.8.0.dev1. + +12. **Update version numbers throughout codebase:** Suppose we just released 0.7.1. The + previous release version number (in this case 0.7.0) and the previous dev version number + (in this case 0.8.0.dev0) appear in many places throughout the code base including + the installation documentation, the example autoscaler config files, and the testing + scripts. Search for all of the occurrences of these version numbers and update them to + use the new release and dev version numbers. + .. _documentation: https://ray.readthedocs.io/en/latest/installation.html#trying-snapshots-from-master .. _`documentation for building wheels`: https://github.com/ray-project/ray/blob/master/python/README-building-wheels.md .. _`ci/stress_tests/run_stress_tests.sh`: https://github.com/ray-project/ray/blob/master/ci/stress_tests/run_stress_tests.sh .. _`ci/stress_tests/run_application_stress_tests.sh`: https://github.com/ray-project/ray/blob/master/ci/stress_tests/run_application_stress_tests.sh .. _`this example`: https://github.com/ray-project/ray/pull/4226 -.. _`these wheels here`: https://ray.readthedocs.io/en/latest/installation.html .. _`GitHub website`: https://github.com/ray-project/ray/releases +.. _`Example PR for minor release`: https://github.com/ray-project/ray/pull/4845 diff --git a/doc/source/api.rst b/doc/source/api.rst index 65e31e5a4ded..a149fbb5bb77 100644 --- a/doc/source/api.rst +++ b/doc/source/api.rst @@ -27,6 +27,26 @@ The Ray API .. autofunction:: ray.method +Inspect the Cluster State +------------------------- + +.. autofunction:: ray.nodes() + +.. autofunction:: ray.tasks() + +.. autofunction:: ray.objects() + +.. autofunction:: ray.timeline() + +.. autofunction:: ray.object_transfer_timeline() + +.. autofunction:: ray.cluster_resources() + +.. autofunction:: ray.available_resources() + +.. autofunction:: ray.errors() + + The Ray Command Line API ------------------------ diff --git a/doc/source/conf.py b/doc/source/conf.py index b67dbe267d4c..e0bd2c6dad4c 100644 --- a/doc/source/conf.py +++ b/doc/source/conf.py @@ -26,6 +26,7 @@ "ray.core.generated.ActorCheckpointIdData", "ray.core.generated.ClientTableData", "ray.core.generated.DriverTableData", + "ray.core.generated.EntryType", "ray.core.generated.ErrorTableData", "ray.core.generated.ErrorType", "ray.core.generated.GcsTableEntry", diff --git a/doc/source/development.rst b/doc/source/development.rst index 66e666b4d1a4..1fdc65fa35cf 100644 --- a/doc/source/development.rst +++ b/doc/source/development.rst @@ -60,7 +60,7 @@ Python script with the following: .. code-block:: bash RAY_RAYLET_GDB=1 RAY_RAYLET_TMUX=1 python - + You can then list the ``tmux`` sessions with ``tmux ls`` and attach to the appropriate one. @@ -71,17 +71,17 @@ allow core dump files to be written. Inspecting Redis shards ~~~~~~~~~~~~~~~~~~~~~~~ -To inspect Redis, you can use the ``ray.experimental.state.GlobalState`` Python -API. 
The easiest way to do this is to start or connect to a Ray cluster with -``ray.init()``, then query the API like so: +To inspect Redis, you can use the global state API. The easiest way to do this +is to start or connect to a Ray cluster with ``ray.init()``, then query the API +like so: .. code-block:: python ray.init() - ray.worker.global_state.client_table() + ray.nodes() # Returns current information about the nodes in the cluster, such as: # [{'ClientID': '2a9d2b34ad24a37ed54e4fcd32bf19f915742f5b', - # 'IsInsertion': True, + # 'EntryType': 0, # 'NodeManagerAddress': '1.2.3.4', # 'NodeManagerPort': 43280, # 'ObjectManagerPort': 38062, diff --git a/doc/source/index.rst b/doc/source/index.rst index 48c0c0d0e662..a90e0224bb02 100644 --- a/doc/source/index.rst +++ b/doc/source/index.rst @@ -7,7 +7,7 @@ Ray Fork me on GitHub -*Ray is a flexible, high-performance distributed execution framework.* +*Ray is a fast and simple framework for building and running distributed applications.* Ray is easy to install: ``pip install ray`` @@ -98,10 +98,10 @@ Ray comes with libraries that accelerate deep learning and reinforcement learnin rllib-models.rst rllib-algorithms.rst rllib-offline.rst - rllib-dev.rst rllib-concepts.rst - rllib-package-ref.rst rllib-examples.rst + rllib-dev.rst + rllib-package-ref.rst .. toctree:: :maxdepth: 1 diff --git a/doc/source/installation.rst b/doc/source/installation.rst index 0aa925d47c63..ad92cb347e83 100644 --- a/doc/source/installation.rst +++ b/doc/source/installation.rst @@ -33,14 +33,14 @@ Here are links to the latest wheels (which are built off of master). To install =================== =================== -.. _`Linux Python 3.7`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.7.0.dev2-cp37-cp37m-manylinux1_x86_64.whl -.. _`Linux Python 3.6`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.7.0.dev2-cp36-cp36m-manylinux1_x86_64.whl -.. _`Linux Python 3.5`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.7.0.dev2-cp35-cp35m-manylinux1_x86_64.whl -.. _`Linux Python 2.7`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.7.0.dev2-cp27-cp27mu-manylinux1_x86_64.whl -.. _`MacOS Python 3.7`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.7.0.dev2-cp37-cp37m-macosx_10_6_intel.whl -.. _`MacOS Python 3.6`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.7.0.dev2-cp36-cp36m-macosx_10_6_intel.whl -.. _`MacOS Python 3.5`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.7.0.dev2-cp35-cp35m-macosx_10_6_intel.whl -.. _`MacOS Python 2.7`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.7.0.dev2-cp27-cp27m-macosx_10_6_intel.whl +.. _`Linux Python 3.7`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-cp37-cp37m-manylinux1_x86_64.whl +.. _`Linux Python 3.6`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-cp36-cp36m-manylinux1_x86_64.whl +.. _`Linux Python 3.5`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-cp35-cp35m-manylinux1_x86_64.whl +.. _`Linux Python 2.7`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-cp27-cp27mu-manylinux1_x86_64.whl +.. _`MacOS Python 3.7`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-cp37-cp37m-macosx_10_6_intel.whl +.. _`MacOS Python 3.6`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-cp36-cp36m-macosx_10_6_intel.whl +.. 
_`MacOS Python 3.5`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-cp35-cp35m-macosx_10_6_intel.whl +.. _`MacOS Python 2.7`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-cp27-cp27m-macosx_10_6_intel.whl Building Ray from source diff --git a/doc/source/internals-overview.rst b/doc/source/internals-overview.rst index 7a762b6bdff0..109b923ecc3b 100644 --- a/doc/source/internals-overview.rst +++ b/doc/source/internals-overview.rst @@ -66,32 +66,6 @@ listens for the addition of remote functions to the centralized control state. When a new remote function is added, the thread fetches the pickled remote function, unpickles it, and can then execute that function. -Notes and limitations -~~~~~~~~~~~~~~~~~~~~~ - -- Because we export remote functions as soon as they are defined, that means - that remote functions can't close over variables that are defined after the - remote function is defined. For example, the following code gives an error. - - .. code-block:: python - - @ray.remote - def f(x): - return helper(x) - - def helper(x): - return x + 1 - - If you call ``f.remote(0)``, it will give an error of the form. - - .. code-block:: python - - Traceback (most recent call last): - File "", line 3, in f - NameError: name 'helper' is not defined - - On the other hand, if ``helper`` is defined before ``f``, then it will work. - Calling a remote function ------------------------- diff --git a/doc/source/rllib-algorithms.rst b/doc/source/rllib-algorithms.rst index 9ee6108476a8..5a07280e3972 100644 --- a/doc/source/rllib-algorithms.rst +++ b/doc/source/rllib-algorithms.rst @@ -95,7 +95,7 @@ Asynchronous Proximal Policy Optimization (APPO) `[implementation] `__ We include an asynchronous variant of Proximal Policy Optimization (PPO) based on the IMPALA architecture. This is similar to IMPALA but using a surrogate policy loss with clipping. Compared to synchronous PPO, APPO is more efficient in wall-clock time due to its use of asynchronous sampling. Using a clipped loss also allows for multiple SGD passes, and therefore the potential for better sample efficiency compared to IMPALA. V-trace can also be enabled to correct for off-policy samples. -This implementation is currently *experimental*. Consider also using `PPO `__ or `IMPALA `__. +APPO is not always more efficient; it is often better to simply use `PPO `__ or `IMPALA `__. Tuned examples: `PongNoFrameskip-v4 `__ @@ -274,7 +274,7 @@ QMIX Monotonic Value Factorisation (QMIX, VDN, IQN) --------------------------------------------------- `[paper] `__ `[implementation] `__ Q-Mix is a specialized multi-agent algorithm. Code here is adapted from https://github.com/oxwhirl/pymarl_alpha to integrate with RLlib multi-agent APIs. To use Q-Mix, you must specify an agent `grouping `__ in the environment (see the `two-step game example `__). Currently, all agents in the group must be homogeneous. The algorithm can be scaled by increasing the number of workers or using Ape-X. -Q-Mix is implemented in `PyTorch `__ and is currently *experimental*. +Q-Mix is implemented in `PyTorch `__ and is currently *experimental*. 
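For reference, a minimal sketch of the agent grouping this requires, in the style of the two-step game example linked above; ``MyTwoAgentEnv``, the agent IDs, and the space sizes below are placeholders for illustration only and are not part of this change:

.. code-block:: python

    from gym.spaces import Discrete, Tuple

    # Hypothetical two-agent MultiAgentEnv; substitute your own environment.
    env = MyTwoAgentEnv()

    # QMIX requires the homogeneous agents to be merged into one logical group.
    # The grouped observation/action spaces are Tuples of the per-agent spaces.
    grouped_env = env.with_agent_groups(
        groups={"group_1": ["agent_1", "agent_2"]},
        obs_space=Tuple([Discrete(6), Discrete(6)]),
        act_space=Tuple([Discrete(2), Discrete(2)]))
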
Tuned examples: `Two-step game `__ diff --git a/doc/source/rllib-concepts.rst b/doc/source/rllib-concepts.rst index d91a29f28b9f..06e890832295 100644 --- a/doc/source/rllib-concepts.rst +++ b/doc/source/rllib-concepts.rst @@ -1,26 +1,27 @@ -RLlib Concepts -============== +RLlib Concepts and Building Custom Algorithms +============================================= This page describes the internal concepts used to implement algorithms in RLlib. You might find this useful if modifying or adding new algorithms to RLlib. -Policy Graphs -------------- +Policies +-------- -Policy graph classes encapsulate the core numerical components of RL algorithms. This typically includes the policy model that determines actions to take, a trajectory postprocessor for experiences, and a loss function to improve the policy given postprocessed experiences. For a simple example, see the policy gradients `graph definition `__. +Policy classes encapsulate the core numerical components of RL algorithms. This typically includes the policy model that determines actions to take, a trajectory postprocessor for experiences, and a loss function to improve the policy given postprocessed experiences. For a simple example, see the policy gradients `graph definition `__. -Most interaction with deep learning frameworks is isolated to the `PolicyGraph interface `__, allowing RLlib to support multiple frameworks. To simplify the definition of policy graphs, RLlib includes `Tensorflow `__ and `PyTorch-specific `__ templates. You can also write your own from scratch. Here is an example: +Most interaction with deep learning frameworks is isolated to the `Policy interface `__, allowing RLlib to support multiple frameworks. To simplify the definition of policies, RLlib includes `Tensorflow <#building-policies-in-tensorflow>`__ and `PyTorch-specific <#building-policies-in-pytorch>`__ templates. You can also write your own from scratch. Here is an example: .. code-block:: python - class CustomPolicy(PolicyGraph): - """Example of a custom policy graph written from scratch. + class CustomPolicy(Policy): + """Example of a custom policy written from scratch. - You might find it more convenient to extend TF/TorchPolicyGraph instead - for a real policy. + You might find it more convenient to use the `build_tf_policy` and + `build_torch_policy` helpers instead for a real policy, which are + described in the next sections. """ def __init__(self, observation_space, action_space, config): - PolicyGraph.__init__(self, observation_space, action_space, config) + Policy.__init__(self, observation_space, action_space, config) # example parameter self.w = 1.0 @@ -45,61 +46,434 @@ Most interaction with deep learning frameworks is isolated to the `PolicyGraph i def set_weights(self, weights): self.w = weights["w"] + +The above basic policy, when run, will produce batches of observations with the basic ``obs``, ``new_obs``, ``actions``, ``rewards``, ``dones``, and ``infos`` columns. There are two more mechanisms to pass along and emit extra information: + +**Policy recurrent state**: Suppose you want to compute actions based on the current timestep of the episode. While it is possible to have the environment provide this as part of the observation, we can instead compute and store it as part of the Policy recurrent state: + +.. 
code-block:: python + + def get_initial_state(self): + """Returns initial RNN state for the current policy.""" + return [0] # list of single state element (t=0) + # you could also return multiple values, e.g., [0, "foo"] + + def compute_actions(self, + obs_batch, + state_batches, + prev_action_batch=None, + prev_reward_batch=None, + info_batch=None, + episodes=None, + **kwargs): + assert len(state_batches) == len(self.get_initial_state()) + new_state_batches = [[ + t + 1 for t in state_batches[0] + ]] + return ..., new_state_batches, {} + + def learn_on_batch(self, samples): + # can access array of the state elements at each timestep + # or state_in_1, 2, etc. if there are multiple state elements + assert "state_in_0" in samples.keys() + assert "state_out_0" in samples.keys() + + +**Extra action info output**: You can also emit extra outputs at each step which will be available for learning on. For example, you might want to output the behaviour policy logits as extra action info, which can be used for importance weighting, but in general arbitrary values can be stored here (as long as they are convertible to numpy arrays): + +.. code-block:: python + + def compute_actions(self, + obs_batch, + state_batches, + prev_action_batch=None, + prev_reward_batch=None, + info_batch=None, + episodes=None, + **kwargs): + action_info_batch = { + "some_value": ["foo" for _ in obs_batch], + "other_value": [12345 for _ in obs_batch], + } + return ..., [], action_info_batch + + def learn_on_batch(self, samples): + # can access array of the extra values at each timestep + assert "some_value" in samples.keys() + assert "other_value" in samples.keys() + + +Building Policies in TensorFlow +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +This section covers how to build a TensorFlow RLlib policy using ``tf_policy_template.build_tf_policy()``. + +To start, you first have to define a loss function. In RLlib, loss functions are defined over batches of trajectory data produced by policy evaluation. A basic policy gradient loss that only tries to maximize the 1-step reward can be defined as follows: + +.. code-block:: python + + import tensorflow as tf + from ray.rllib.policy.sample_batch import SampleBatch + + def policy_gradient_loss(policy, batch_tensors): + actions = batch_tensors[SampleBatch.ACTIONS] + rewards = batch_tensors[SampleBatch.REWARDS] + return -tf.reduce_mean(policy.action_dist.logp(actions) * rewards) + +In the above snippet, ``actions`` is a Tensor placeholder of shape ``[batch_size, action_dim...]``, and ``rewards`` is a placeholder of shape ``[batch_size]``. The ``policy.action_dist`` object is an `ActionDistribution `__ that represents the output of the neural network policy model. Passing this loss function to ``build_tf_policy`` is enough to produce a very basic TF policy: + +.. code-block:: python + + from ray.rllib.policy.tf_policy_template import build_tf_policy + + # + MyTFPolicy = build_tf_policy( + name="MyTFPolicy", + loss_fn=policy_gradient_loss) + +We can create a `Trainer <#trainers>`__ and try running this policy on a toy env with two parallel rollout workers: + +.. code-block:: python + + import ray + from ray import tune + from ray.rllib.agents.trainer_template import build_trainer + + # + MyTrainer = build_trainer( + name="MyCustomTrainer", + default_policy=MyTFPolicy) + + ray.init() + tune.run(MyTrainer, config={"env": "CartPole-v0", "num_workers": 2}) + + +If you run the above snippet, you'll probably notice that CartPole doesn't learn so well: + +.. 
code-block:: bash + + == Status == + Using FIFO scheduling algorithm. + Resources requested: 3/4 CPUs, 0/0 GPUs + Memory usage on this node: 4.6/12.3 GB + Result logdir: /home/ubuntu/ray_results/MyAlgTrainer + Number of trials: 1 ({'RUNNING': 1}) + RUNNING trials: + - MyAlgTrainer_CartPole-v0_0: RUNNING, [3 CPUs, 0 GPUs], [pid=26784], + 32 s, 156 iter, 62400 ts, 23.1 rew + +Let's modify our policy loss to include rewards summed over time. To enable this advantage calculation, we need to define a *trajectory postprocessor* for the policy. This can be done by defining ``postprocess_fn``: + +.. code-block:: python + + from ray.rllib.evaluation.postprocessing import compute_advantages, \ + Postprocessing + + def postprocess_advantages(policy, + sample_batch, + other_agent_batches=None, + episode=None): + return compute_advantages( + sample_batch, 0.0, policy.config["gamma"], use_gae=False) + + def policy_gradient_loss(policy, batch_tensors): + actions = batch_tensors[SampleBatch.ACTIONS] + advantages = batch_tensors[Postprocessing.ADVANTAGES] + return -tf.reduce_mean(policy.action_dist.logp(actions) * advantages) + + MyTFPolicy = build_tf_policy( + name="MyTFPolicy", + loss_fn=policy_gradient_loss, + postprocess_fn=postprocess_advantages) + +The ``postprocess_advantages()`` function above uses calls RLlib's ``compute_advantages`` function to compute advantages for each timestep. If you re-run the trainer with this improved policy, you'll find that it quickly achieves the max reward of 200. + +You might be wondering how RLlib makes the advantages placeholder automatically available as ``batch_tensors[Postprocessing.ADVANTAGES]``. When building your policy, RLlib will create a "dummy" trajectory batch where all observations, actions, rewards, etc. are zeros. It then calls your ``postprocess_fn``, and generates TF placeholders based on the numpy shapes of the postprocessed batch. RLlib tracks which placeholders that ``loss_fn`` and ``stats_fn`` access, and then feeds the corresponding sample data into those placeholders during loss optimization. You can also access these placeholders via ``policy.get_placeholder()`` after loss initialization. + +**Example 1: Proximal Policy Optimization** + +In the above section you saw how to compose a simple policy gradient algorithm with RLlib. In this example, we'll dive into how PPO was built with RLlib and how you can modify it. First, check out the `PPO trainer definition `__: + +.. code-block:: python + + PPOTrainer = build_trainer( + name="PPOTrainer", + default_config=DEFAULT_CONFIG, + default_policy=PPOTFPolicy, + make_policy_optimizer=choose_policy_optimizer, + validate_config=validate_config, + after_optimizer_step=update_kl, + before_train_step=warn_about_obs_filter, + after_train_result=warn_about_bad_reward_scales) + +Besides some boilerplate for defining the PPO configuration and some warnings, there are two important arguments to take note of here: ``make_policy_optimizer=choose_policy_optimizer``, and ``after_optimizer_step=update_kl``. + +The ``choose_policy_optimizer`` function chooses which `Policy Optimizer <#policy-optimization>`__ to use for distributed training. You can think of these policy optimizers as coordinating the distributed workflow needed to improve the policy. Depending on the trainer config, PPO can switch between a simple synchronous optimizer (the default), or a multi-GPU optimizer that implements minibatch SGD: + +.. 
code-block:: python + + def choose_policy_optimizer(workers, config): + if config["simple_optimizer"]: + return SyncSamplesOptimizer( + workers, + num_sgd_iter=config["num_sgd_iter"], + train_batch_size=config["train_batch_size"]) + + return LocalMultiGPUOptimizer( + workers, + sgd_batch_size=config["sgd_minibatch_size"], + num_sgd_iter=config["num_sgd_iter"], + num_gpus=config["num_gpus"], + sample_batch_size=config["sample_batch_size"], + num_envs_per_worker=config["num_envs_per_worker"], + train_batch_size=config["train_batch_size"], + standardize_fields=["advantages"], + straggler_mitigation=config["straggler_mitigation"]) + +Suppose we want to customize PPO to use an asynchronous-gradient optimization strategy similar to A3C. To do that, we could define a new function that returns ``AsyncGradientsOptimizer`` and pass in ``make_policy_optimizer=make_async_optimizer`` when building the trainer: + +.. code-block:: python + + from ray.rllib.agents.ppo.ppo_policy import * + from ray.rllib.optimizers import AsyncGradientsOptimizer + from ray.rllib.policy.tf_policy_template import build_tf_policy + + def make_async_optimizer(workers, config): + return AsyncGradientsOptimizer(workers, grads_per_step=100) + + PPOTrainer = build_trainer( + ..., + make_policy_optimizer=make_async_optimizer) + + +Now let's take a look at the ``update_kl`` function. This is used to adaptively adjust the KL penalty coefficient on the PPO loss, which bounds the policy change per training step. You'll notice the code handles both single and multi-agent cases (where there are be multiple policies each with different KL coeffs): + +.. code-block:: python + + def update_kl(trainer, fetches): + if "kl" in fetches: + # single-agent + trainer.workers.local_worker().for_policy( + lambda pi: pi.update_kl(fetches["kl"])) + else: + + def update(pi, pi_id): + if pi_id in fetches: + pi.update_kl(fetches[pi_id]["kl"]) + else: + logger.debug("No data for {}, not updating kl".format(pi_id)) + + # multi-agent + trainer.workers.local_worker().foreach_trainable_policy(update) + +The ``update_kl`` method on the policy is defined in `PPOTFPolicy `__ via the ``KLCoeffMixin``, along with several other advanced features. Let's look at each new feature used by the policy: + +.. code-block:: python + + PPOTFPolicy = build_tf_policy( + name="PPOTFPolicy", + get_default_config=lambda: ray.rllib.agents.ppo.ppo.DEFAULT_CONFIG, + loss_fn=ppo_surrogate_loss, + stats_fn=kl_and_loss_stats, + extra_action_fetches_fn=vf_preds_and_logits_fetches, + postprocess_fn=postprocess_ppo_gae, + gradients_fn=clip_gradients, + before_loss_init=setup_mixins, + mixins=[LearningRateSchedule, KLCoeffMixin, ValueNetworkMixin]) + +``stats_fn``: The stats function returns a dictionary of Tensors that will be reported with the training results. This also includes the ``kl`` metric which is used by the trainer to adjust the KL penalty. Note that many of the values below reference ``policy.loss_obj``, which is assigned by ``loss_fn`` (not shown here since the PPO loss is quite complex). RLlib will always call ``stats_fn`` after ``loss_fn``, so you can rely on using values saved by ``loss_fn`` as part of your statistics: + +.. 
code-block:: python + + def kl_and_loss_stats(policy, batch_tensors): + policy.explained_variance = explained_variance( + batch_tensors[Postprocessing.VALUE_TARGETS], policy.value_function) + + stats_fetches = { + "cur_kl_coeff": policy.kl_coeff, + "cur_lr": tf.cast(policy.cur_lr, tf.float64), + "total_loss": policy.loss_obj.loss, + "policy_loss": policy.loss_obj.mean_policy_loss, + "vf_loss": policy.loss_obj.mean_vf_loss, + "vf_explained_var": policy.explained_variance, + "kl": policy.loss_obj.mean_kl, + "entropy": policy.loss_obj.mean_entropy, + } + + return stats_fetches + +``extra_actions_fetches_fn``: This function defines extra outputs that will be recorded when generating actions with the policy. For example, this enables saving the raw policy logits in the experience batch, which e.g. means it can be referenced in the PPO loss function via ``batch_tensors[BEHAVIOUR_LOGITS]``. Other values such as the current value prediction can also be emitted for debugging or optimization purposes: + +.. code-block:: python + + def vf_preds_and_logits_fetches(policy): + return { + SampleBatch.VF_PREDS: policy.value_function, + BEHAVIOUR_LOGITS: policy.model.outputs, + } + +``gradients_fn``: If defined, this function returns TF gradients for the loss function. You'd typically only want to override this to apply transformations such as gradient clipping: + +.. code-block:: python + + def clip_gradients(policy, optimizer, loss): + if policy.config["grad_clip"] is not None: + policy.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, + tf.get_variable_scope().name) + grads = tf.gradients(loss, policy.var_list) + policy.grads, _ = tf.clip_by_global_norm(grads, + policy.config["grad_clip"]) + clipped_grads = list(zip(policy.grads, policy.var_list)) + return clipped_grads + else: + return optimizer.compute_gradients( + loss, colocate_gradients_with_ops=True) + +``mixins``: To add arbitrary stateful components, you can add mixin classes to the policy. Methods defined by these mixins will have higher priority than the base policy class, so you can use these to override methods (as in the case of ``LearningRateSchedule``), or define extra methods and attributes (e.g., ``KLCoeffMixin``, ``ValueNetworkMixin``). Like any other Python superclass, these should be initialized at some point, which is what the ``setup_mixins`` function does: + +.. code-block:: python + + def setup_mixins(policy, obs_space, action_space, config): + ValueNetworkMixin.__init__(policy, obs_space, action_space, config) + KLCoeffMixin.__init__(policy, config) + LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"]) + +In PPO we run ``setup_mixins`` before the loss function is called (i.e., ``before_loss_init``), but other callbacks you can use include ``before_init`` and ``after_init``. + +**Example 2: Deep Q Networks** + +(todo) + +Finally, note that you do not have to use ``build_tf_policy`` to define a TensorFlow policy. You can alternatively subclass ``Policy``, ``TFPolicy``, or ``DynamicTFPolicy`` as convenient. + +Building Policies in PyTorch +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Defining a policy in PyTorch is quite similar to that for TensorFlow (and the process of defining a trainer given a Torch policy is exactly the same). Building on the TF examples above, let's look at how the `A3C torch policy `__ is defined: + +.. 
code-block:: python + + A3CTorchPolicy = build_torch_policy( + name="A3CTorchPolicy", + get_default_config=lambda: ray.rllib.agents.a3c.a3c.DEFAULT_CONFIG, + loss_fn=actor_critic_loss, + stats_fn=loss_and_entropy_stats, + postprocess_fn=add_advantages, + extra_action_out_fn=model_value_predictions, + extra_grad_process_fn=apply_grad_clipping, + optimizer_fn=torch_optimizer, + mixins=[ValueNetworkMixin]) + +``loss_fn``: Similar to the TF example, the actor critic loss is defined over ``batch_tensors``. We imperatively execute the forward pass by calling ``policy.model()`` on the observations followed by ``policy.dist_class()`` on the output logits. The output Tensors are saved as attributes of the policy object (e.g., ``policy.entropy = dist.entropy.mean()``), and we return the scalar loss: + +.. code-block:: python + + def actor_critic_loss(policy, batch_tensors): + logits, _, values, _ = policy.model({ + SampleBatch.CUR_OBS: batch_tensors[SampleBatch.CUR_OBS] + }, []) + dist = policy.dist_class(logits) + log_probs = dist.logp(batch_tensors[SampleBatch.ACTIONS]) + policy.entropy = dist.entropy().mean() + ... + return overall_err + +``stats_fn``: The stats function references ``entropy``, ``pi_err``, and ``value_err`` saved from the call to the loss function, similar in the PPO TF example: + +.. code-block:: python + + def loss_and_entropy_stats(policy, batch_tensors): + return { + "policy_entropy": policy.entropy.item(), + "policy_loss": policy.pi_err.item(), + "vf_loss": policy.value_err.item(), + } + +``extra_action_out_fn``: We save value function predictions given model outputs. This makes the value function predictions of the model available in the trajectory as ``batch_tensors[SampleBatch.VF_PREDS]``: + +.. code-block:: python + + def model_value_predictions(policy, model_out): + return {SampleBatch.VF_PREDS: model_out[2].cpu().numpy()} + +``postprocess_fn`` and ``mixins``: Similar to the PPO example, we need access to the value function during postprocessing (i.e., ``add_advantages`` below calls ``policy._value()``. The value function is exposed through a mixin class that defines the method: + +.. code-block:: python + + def add_advantages(policy, + sample_batch, + other_agent_batches=None, + episode=None): + completed = sample_batch[SampleBatch.DONES][-1] + if completed: + last_r = 0.0 + else: + last_r = policy._value(sample_batch[SampleBatch.NEXT_OBS][-1]) + return compute_advantages(sample_batch, last_r, policy.config["gamma"], + policy.config["lambda"]) + + class ValueNetworkMixin(object): + def _value(self, obs): + with self.lock: + obs = torch.from_numpy(obs).float().unsqueeze(0).to(self.device) + _, _, vf, _ = self.model({"obs": obs}, []) + return vf.detach().cpu().numpy().squeeze() + +You can find the full policy definition in `a3c_torch_policy.py `__. + +In summary, the main differences between the PyTorch and TensorFlow policy builder functions is that the TF loss and stats functions are built symbolically when the policy is initialized, whereas for PyTorch these functions are called imperatively each time they are used. + Policy Evaluation ----------------- -Given an environment and policy graph, policy evaluation produces `batches `__ of experiences. This is your classic "environment interaction loop". Efficient policy evaluation can be burdensome to get right, especially when leveraging vectorization, RNNs, or when operating in a multi-agent environment. 
RLlib provides a `PolicyEvaluator `__ class that manages all of this, and this class is used in most RLlib algorithms. +Given an environment and policy, policy evaluation produces `batches `__ of experiences. This is your classic "environment interaction loop". Efficient policy evaluation can be burdensome to get right, especially when leveraging vectorization, RNNs, or when operating in a multi-agent environment. RLlib provides a `RolloutWorker `__ class that manages all of this, and this class is used in most RLlib algorithms. -You can use policy evaluation standalone to produce batches of experiences. This can be done by calling ``ev.sample()`` on an evaluator instance, or ``ev.sample.remote()`` in parallel on evaluator instances created as Ray actors (see ``PolicyEvaluator.as_remote()``). +You can use rollout workers standalone to produce batches of experiences. This can be done by calling ``worker.sample()`` on a worker instance, or ``worker.sample.remote()`` in parallel on worker instances created as Ray actors (see ``RolloutWorkers.create_remote``). -Here is an example of creating a set of policy evaluation actors and using the to gather experiences in parallel. The trajectories are concatenated, the policy learns on the trajectory batch, and then we broadcast the policy weights to the evaluators for the next round of rollouts: +Here is an example of creating a set of rollout workers and using them gather experiences in parallel. The trajectories are concatenated, the policy learns on the trajectory batch, and then we broadcast the policy weights to the workers for the next round of rollouts: .. code-block:: python - # Setup policy and remote policy evaluation actors + # Setup policy and rollout workers env = gym.make("CartPole-v0") policy = CustomPolicy(env.observation_space, env.action_space, {}) - remote_evaluators = [ - PolicyEvaluator.as_remote().remote(lambda c: gym.make("CartPole-v0"), - CustomPolicy) - for _ in range(10) - ] + workers = WorkerSet( + policy=CustomPolicy, + env_creator=lambda c: gym.make("CartPole-v0"), + num_workers=10) while True: # Gather a batch of samples T1 = SampleBatch.concat_samples( - ray.get([w.sample.remote() for w in remote_evaluators])) + ray.get([w.sample.remote() for w in workers.remote_workers()])) # Improve the policy using the T1 batch policy.learn_on_batch(T1) # Broadcast weights to the policy evaluation workers weights = ray.put({"default_policy": policy.get_weights()}) - for w in remote_evaluators: + for w in workers.remote_workers(): w.set_weights.remote(weights) Policy Optimization ------------------- -Similar to how a `gradient-descent optimizer `__ can be used to improve a model, RLlib's `policy optimizers `__ implement different strategies for improving a policy graph. +Similar to how a `gradient-descent optimizer `__ can be used to improve a model, RLlib's `policy optimizers `__ implement different strategies for improving a policy. -For example, in A3C you'd want to compute gradients asynchronously on different workers, and apply them to a central policy graph replica. This strategy is implemented by the `AsyncGradientsOptimizer `__. Another alternative is to gather experiences synchronously in parallel and optimize the model centrally, as in `SyncSamplesOptimizer `__. Policy optimizers abstract these strategies away into reusable modules. +For example, in A3C you'd want to compute gradients asynchronously on different workers, and apply them to a central policy replica. 
This strategy is implemented by the `AsyncGradientsOptimizer `__. Another alternative is to gather experiences synchronously in parallel and optimize the model centrally, as in `SyncSamplesOptimizer `__. Policy optimizers abstract these strategies away into reusable modules. This is how the example in the previous section looks when written using a policy optimizer: .. code-block:: python # Same setup as before - local_evaluator = PolicyEvaluator(lambda c: gym.make("CartPole-v0"), CustomPolicy) - remote_evaluators = [ - PolicyEvaluator.as_remote().remote(lambda c: gym.make("CartPole-v0"), - CustomPolicy) - for _ in range(10) - ] + workers = WorkerSet( + policy=CustomPolicy, + env_creator=lambda c: gym.make("CartPole-v0"), + num_workers=10) # this optimizer implements the IMPALA architecture - optimizer = AsyncSamplesOptimizer( - local_evaluator, remote_evaluators, train_batch_size=500) + optimizer = AsyncSamplesOptimizer(workers, train_batch_size=500) while True: optimizer.step() @@ -108,9 +482,9 @@ This is how the example in the previous section looks when written using a polic Trainers -------- -Trainers are the boilerplate classes that put the above components together, making algorithms accessible via Python API and the command line. They manage algorithm configuration, setup of the policy evaluators and optimizer, and collection of training metrics. Trainers also implement the `Trainable API `__ for easy experiment management. +Trainers are the boilerplate classes that put the above components together, making algorithms accessible via Python API and the command line. They manage algorithm configuration, setup of the rollout workers and optimizer, and collection of training metrics. Trainers also implement the `Trainable API `__ for easy experiment management. -Example of two equivalent ways of interacting with the PPO trainer: +Example of three equivalent ways of interacting with the PPO trainer, all of which log results in ``~/ray_results``: .. code-block:: python @@ -121,3 +495,8 @@ Example of two equivalent ways of interacting with the PPO trainer: .. code-block:: bash rllib train --run=PPO --env=CartPole-v0 --config='{"train_batch_size": 4000}' + +.. code-block:: python + + from ray import tune + tune.run(PPOTrainer, config={"env": "CartPole-v0", "train_batch_size": 4000}) diff --git a/doc/source/rllib-env.rst b/doc/source/rllib-env.rst index 056b7c3fc791..b04b91c3c265 100644 --- a/doc/source/rllib-env.rst +++ b/doc/source/rllib-env.rst @@ -13,7 +13,7 @@ Algorithm Discrete Actions Continuous Actions Multi-Agent Recurre A2C, A3C **Yes** `+parametric`_ **Yes** **Yes** **Yes** PPO, APPO **Yes** `+parametric`_ **Yes** **Yes** **Yes** PG **Yes** `+parametric`_ **Yes** **Yes** **Yes** -IMPALA **Yes** `+parametric`_ No **Yes** **Yes** +IMPALA **Yes** `+parametric`_ **Yes** **Yes** **Yes** DQN, Rainbow **Yes** `+parametric`_ No **Yes** No DDPG, TD3 No **Yes** **Yes** No APEX-DQN **Yes** `+parametric`_ No **Yes** No @@ -92,7 +92,7 @@ In the above example, note that the ``env_creator`` function takes in an ``env_c OpenAI Gym ---------- -RLlib uses Gym as its environment interface for single-agent training. For more information on how to implement a custom Gym environment, see the `gym.Env class definition `__. You may also find the `SimpleCorridor `__ and `Carla simulator `__ example env implementations useful as a reference. +RLlib uses Gym as its environment interface for single-agent training. 
For more information on how to implement a custom Gym environment, see the `gym.Env class definition `__. You may find the `SimpleCorridor `__ example useful as a reference. Performance ~~~~~~~~~~~ @@ -167,8 +167,8 @@ If all the agents will be using the same algorithm class to train, then you can trainer = pg.PGAgent(env="my_multiagent_env", config={ "multiagent": { - "policy_graphs": { - # the first tuple value is None -> uses default policy graph + "policies": { + # the first tuple value is None -> uses default policy "car1": (None, car_obs_space, car_act_space, {"gamma": 0.85}), "car2": (None, car_obs_space, car_act_space, {"gamma": 0.99}), "traffic_light": (None, tl_obs_space, tl_act_space, {}), @@ -234,10 +234,10 @@ This can be implemented as a multi-agent environment with three types of agents. .. code-block:: python "multiagent": { - "policy_graphs": { - "top_level": (custom_policy_graph or None, ...), - "mid_level": (custom_policy_graph or None, ...), - "low_level": (custom_policy_graph or None, ...), + "policies": { + "top_level": (custom_policy or None, ...), + "mid_level": (custom_policy or None, ...), + "low_level": (custom_policy or None, ...), }, "policy_mapping_fn": lambda agent_id: @@ -269,13 +269,13 @@ There is a full example of this in the `example training script `__. +2. Updating the critic: the centralized critic loss can be added to the loss of the custom policy, the same as with any other value function. For an example of defining loss inputs, see the `PGPolicy example `__. Grouping Agents ~~~~~~~~~~~~~~~ diff --git a/doc/source/rllib-models.rst b/doc/source/rllib-models.rst index 7fd860a65a3e..cdf42ea228c7 100644 --- a/doc/source/rllib-models.rst +++ b/doc/source/rllib-models.rst @@ -101,7 +101,7 @@ Custom TF models should subclass the common RLlib `model class `__ and associated `training scripts `__. You can also reference the `unit tests `__ for Tuple and Dict spaces, which show how to access nested observation fields. +For a full example of a custom model in code, see the `custom env example `__. You can also reference the `unit tests `__ for Tuple and Dict spaces, which show how to access nested observation fields. Custom Recurrent Models ~~~~~~~~~~~~~~~~~~~~~~~ @@ -175,7 +175,7 @@ Instead of using the ``use_lstm: True`` option, it can be preferable use a custo Batch Normalization ~~~~~~~~~~~~~~~~~~~ -You can use ``tf.layers.batch_normalization(x, training=input_dict["is_training"])`` to add batch norm layers to your custom model: `code example `__. RLlib will automatically run the update ops for the batch norm layers during optimization (see `tf_policy_graph.py `__ and `multi_gpu_impl.py `__ for the exact handling of these updates). +You can use ``tf.layers.batch_normalization(x, training=input_dict["is_training"])`` to add batch norm layers to your custom model: `code example `__. RLlib will automatically run the update ops for the batch norm layers during optimization (see `tf_policy.py `__ and `multi_gpu_impl.py `__ for the exact handling of these updates). Custom Models (PyTorch) ----------------------- @@ -263,7 +263,7 @@ You can mix supervised losses into any RLlib algorithm through custom models. Fo **TensorFlow**: To add a supervised loss to a custom TF model, you need to override the ``custom_loss()`` method. This method takes in the existing policy loss for the algorithm, which you can add your own supervised loss to before returning. For debugging, you can also return a dictionary of scalar tensors in the ``custom_metrics()`` method. 
Here is a `runnable example `__ of adding an imitation loss to CartPole training that is defined over a `offline dataset `__. -**PyTorch**: There is no explicit API for adding losses to custom torch models. However, you can modify the loss in the policy graph definition directly. Like for TF models, offline datasets can be incorporated by creating an input reader and calling ``reader.next()`` in the loss forward pass. +**PyTorch**: There is no explicit API for adding losses to custom torch models. However, you can modify the loss in the policy definition directly. Like for TF models, offline datasets can be incorporated by creating an input reader and calling ``reader.next()`` in the loss forward pass. Variable-length / Parametric Action Spaces @@ -312,15 +312,15 @@ Custom models can be used to work with environments where (1) the set of valid a Depending on your use case it may make sense to use just the masking, just action embeddings, or both. For a runnable example of this in code, check out `parametric_action_cartpole.py `__. Note that since masking introduces ``tf.float32.min`` values into the model output, this technique might not work with all algorithm options. For example, algorithms might crash if they incorrectly process the ``tf.float32.min`` values. The cartpole example has working configurations for DQN (must set ``hiddens=[]``), PPO (must disable running mean and set ``vf_share_layers=True``), and several other algorithms. -Customizing Policy Graphs +Customizing Policies ------------------------- -For deeper customization of algorithms, you can modify the policy graphs of the trainer classes. Here's an example of extending the DDPG policy graph to specify custom sub-network modules: +For deeper customization of algorithms, you can modify the policies of the trainer classes. Here's an example of extending the DDPG policy to specify custom sub-network modules: .. code-block:: python from ray.rllib.models import ModelCatalog - from ray.rllib.agents.ddpg.ddpg_policy_graph import DDPGPolicyGraph as BaseDDPGPolicyGraph + from ray.rllib.agents.ddpg.ddpg_policy import DDPGTFPolicy as BaseDDPGTFPolicy class CustomPNetwork(object): def __init__(self, dim_actions, hiddens, activation): @@ -336,7 +336,7 @@ For deeper customization of algorithms, you can modify the policy graphs of the self.value = layers.fully_connected( q_out, num_outputs=1, activation_fn=None) - class CustomDDPGPolicyGraph(BaseDDPGPolicyGraph): + class CustomDDPGTFPolicy(BaseDDPGTFPolicy): def _build_p_network(self, obs): return CustomPNetwork( self.dim_actions, @@ -349,26 +349,26 @@ For deeper customization of algorithms, you can modify the policy graphs of the self.config["critic_hiddens"], self.config["critic_hidden_activation"]).value -Then, you can create an trainer with your custom policy graph by: +Then, you can create an trainer with your custom policy by: .. code-block:: python from ray.rllib.agents.ddpg.ddpg import DDPGTrainer - from custom_policy_graph import CustomDDPGPolicyGraph + from custom_policy import CustomDDPGTFPolicy - DDPGTrainer._policy_graph = CustomDDPGPolicyGraph + DDPGTrainer._policy = CustomDDPGTFPolicy trainer = DDPGTrainer(...) -In this example we overrode existing methods of the existing DDPG policy graph, i.e., `_build_q_network`, `_build_p_network`, `_build_action_network`, `_build_actor_critic_loss`, but you can also replace the entire graph class entirely. 
+In this example we overrode existing methods of the existing DDPG policy, i.e., `_build_q_network`, `_build_p_network`, `_build_action_network`, `_build_actor_critic_loss`, but you can also replace the entire graph class entirely. Model-Based Rollouts ~~~~~~~~~~~~~~~~~~~~ -With a custom policy graph, you can also perform model-based rollouts and optionally incorporate the results of those rollouts as training data. For example, suppose you wanted to extend PGPolicyGraph for model-based rollouts. This involves overriding the ``compute_actions`` method of that policy graph: +With a custom policy, you can also perform model-based rollouts and optionally incorporate the results of those rollouts as training data. For example, suppose you wanted to extend PGPolicy for model-based rollouts. This involves overriding the ``compute_actions`` method of that policy: .. code-block:: python - class ModelBasedPolicyGraph(PGPolicyGraph): + class ModelBasedPolicy(PGPolicy): def compute_actions(self, obs_batch, state_batches, diff --git a/doc/source/rllib-offline.rst b/doc/source/rllib-offline.rst index 42dd5f5b4909..825038af3d53 100644 --- a/doc/source/rllib-offline.rst +++ b/doc/source/rllib-offline.rst @@ -6,7 +6,7 @@ Working with Offline Datasets RLlib's offline dataset APIs enable working with experiences read from offline storage (e.g., disk, cloud storage, streaming systems, HDFS). For example, you might want to read experiences saved from previous training runs, or gathered from policies deployed in `web applications `__. You can also log new agent experiences produced during online training for future use. -RLlib represents trajectory sequences (i.e., ``(s, a, r, s', ...)`` tuples) with `SampleBatch `__ objects. Using a batch format enables efficient encoding and compression of experiences. During online training, RLlib uses `policy evaluation `__ actors to generate batches of experiences in parallel using the current policy. RLlib also uses this same batch format for reading and writing experiences to offline storage. +RLlib represents trajectory sequences (i.e., ``(s, a, r, s', ...)`` tuples) with `SampleBatch `__ objects. Using a batch format enables efficient encoding and compression of experiences. During online training, RLlib uses `policy evaluation `__ actors to generate batches of experiences in parallel using the current policy. RLlib also uses this same batch format for reading and writing experiences to offline storage. Example: Training on previously saved experiences ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -65,7 +65,7 @@ This example plot shows the Q-value metric in addition to importance sampling (I .. image:: offline-q.png -**Estimator Python API:** For greater control over the evaluation process, you can create off-policy estimators in your Python code and call ``estimator.estimate(episode_batch)`` to perform counterfactual estimation as needed. The estimators take in a policy graph object and gamma value for the environment: +**Estimator Python API:** For greater control over the evaluation process, you can create off-policy estimators in your Python code and call ``estimator.estimate(episode_batch)`` to perform counterfactual estimation as needed. The estimators take in a policy object and gamma value for the environment: .. code-block:: python @@ -99,7 +99,7 @@ This `runnable example `__. This isn't typically critical for off-policy algorithms (e.g., DQN's `post-processing `__ is only needed if ``n_step > 1`` or ``worker_side_prioritization: True``). 
For off-policy algorithms, you can also safely set the ``postprocess_inputs: True`` config to auto-postprocess data. +RLlib assumes that input batches are of `postprocessed experiences `__. This isn't typically critical for off-policy algorithms (e.g., DQN's `post-processing `__ is only needed if ``n_step > 1`` or ``worker_side_prioritization: True``). For off-policy algorithms, you can also safely set the ``postprocess_inputs: True`` config to auto-postprocess data. However, for on-policy algorithms like PPO, you'll need to pass in the extra values added during policy evaluation and postprocessing to ``batch_builder.add_values()``, e.g., ``logits``, ``vf_preds``, ``value_target``, and ``advantages`` for PPO. This is needed since the calculation of these values depends on the parameters of the *behaviour* policy, which RLlib does not have access to in the offline setting (in online training, these values are automatically added during policy evaluation). diff --git a/doc/source/rllib.rst b/doc/source/rllib.rst index 06c580035507..e77a0ab427f8 100644 --- a/doc/source/rllib.rst +++ b/doc/source/rllib.rst @@ -5,8 +5,7 @@ RLlib is an open-source library for reinforcement learning that offers both high .. image:: rllib-stack.svg -Learn more about RLlib's design by reading the `ICML paper `__. -To get started, take a look over the `custom env example `__ and the `API documentation `__. +To get started, take a look over the `custom env example `__ and the `API documentation `__. If you're looking to develop custom algorithms with RLlib, also check out `concepts and custom algorithms `__. Installation ------------ @@ -50,7 +49,7 @@ Models and Preprocessors * `Custom Preprocessors `__ * `Supervised Model Losses `__ * `Variable-length / Parametric Action Spaces `__ -* `Customizing Policy Graphs `__ +* `Customizing Policies `__ Algorithms ---------- @@ -96,12 +95,17 @@ Offline Datasets * `Input API `__ * `Output API `__ -Concepts --------- -* `Policy Graphs `__ -* `Policy Evaluation `__ -* `Policy Optimization `__ -* `Trainers `__ +Concepts and Building Custom Algorithms +--------------------------------------- +* `Policies `__ + + - `Building Policies in TensorFlow `__ + + - `Building Policies in PyTorch `__ + +* `Policy Evaluation `__ +* `Policy Optimization `__ +* `Trainers `__ Examples -------- diff --git a/doc/source/tune-searchalg.rst b/doc/source/tune-searchalg.rst index 0a2bf491a676..2dae8eaf4abe 100644 --- a/doc/source/tune-searchalg.rst +++ b/doc/source/tune-searchalg.rst @@ -171,7 +171,9 @@ This algorithm requires specifying a search space and objective. You can use `Ax .. code-block:: python - tune.run(... , search_alg=AxSearch(parameter_dicts, ... )) + client = AxClient(enforce_sequential_optimization=False) + client.create_experiment( ... ) + tune.run(... , search_alg=AxSearch(client)) An example of this can be found in `ax_example.py `__. diff --git a/doc/source/tune.rst b/doc/source/tune.rst index bfeb729e6469..2674f5b064a7 100644 --- a/doc/source/tune.rst +++ b/doc/source/tune.rst @@ -7,9 +7,9 @@ Tune: Scalable Hyperparameter Search Tune is a scalable framework for hyperparameter search with a focus on deep learning and deep reinforcement learning. -You can find the code for Tune `here on GitHub `__. To get started with Tune, try going through `our tutorial of using Tune with Keras `__. +You can find the code for Tune `here on GitHub `__. To get started with Tune, try going through `our tutorial of using Tune with Keras `__. 
-(Experimental): You can try out `the above tutorial on a free hosted server via Binder `__. +(Experimental): You can try out `the above tutorial on a free hosted server via Binder `__. Features diff --git a/doc/source/tutorial.rst b/doc/source/tutorial.rst index c1b26d155e8c..889ca77c50c3 100644 --- a/doc/source/tutorial.rst +++ b/doc/source/tutorial.rst @@ -9,9 +9,9 @@ To use Ray, you need to understand the following: Overview -------- -Ray is a distributed execution engine. The same code can be run on -a single machine to achieve efficient multiprocessing, and it can be used on a -cluster for large computations. +Ray is a fast and simple framework for building and running distributed applications. +The same code can be run on a single machine to achieve efficient multiprocessing, +and it can be used on a cluster for large computations. When using Ray, several processes are involved. diff --git a/doc/source/user-profiling.rst b/doc/source/user-profiling.rst index 4bf152e52a00..511531f061a8 100644 --- a/doc/source/user-profiling.rst +++ b/doc/source/user-profiling.rst @@ -18,7 +18,7 @@ following command. .. code-block:: python - ray.global_state.chrome_tracing_dump(filename="/tmp/timeline.json") + ray.timeline(filename="/tmp/timeline.json") Then open `chrome://tracing`_ in the Chrome web browser, and load ``timeline.json``. diff --git a/docker/stress_test/Dockerfile b/docker/stress_test/Dockerfile index 0891ac02c8f9..664370eb0479 100644 --- a/docker/stress_test/Dockerfile +++ b/docker/stress_test/Dockerfile @@ -4,7 +4,7 @@ FROM ray-project/base-deps # We install ray and boto3 to enable the ray autoscaler as # a test runner. -RUN pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.7.0.dev2-cp27-cp27mu-manylinux1_x86_64.whl boto3 +RUN pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-cp27-cp27mu-manylinux1_x86_64.whl boto3 RUN mkdir -p /root/.ssh/ # We port the source code in so that we run the most up-to-date stress tests. diff --git a/docker/tune_test/Dockerfile b/docker/tune_test/Dockerfile index 9546b676b779..b0cf426c1b1d 100644 --- a/docker/tune_test/Dockerfile +++ b/docker/tune_test/Dockerfile @@ -4,7 +4,7 @@ FROM ray-project/base-deps # We install ray and boto3 to enable the ray autoscaler as # a test runner. -RUN pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.7.0.dev2-cp27-cp27mu-manylinux1_x86_64.whl boto3 +RUN pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-cp27-cp27mu-manylinux1_x86_64.whl boto3 # We install this after the latest wheels -- this should not override the latest wheels. 
RUN apt-get install -y zlib1g-dev RUN pip install gym[atari]==0.10.11 opencv-python-headless tensorflow lz4 keras pytest-timeout smart_open diff --git a/java/BUILD.bazel b/java/BUILD.bazel index 34799a76c78a..f86df8d40f96 100644 --- a/java/BUILD.bazel +++ b/java/BUILD.bazel @@ -1,4 +1,4 @@ -load("//bazel:ray.bzl", "flatbuffer_java_library", "define_java_module", "gen_java_pom_file") +load("//bazel:ray.bzl", "flatbuffer_java_library", "define_java_module") exports_files([ "testng.xml", @@ -7,27 +7,29 @@ exports_files([ "streaming/testng.xml", ]) +all_modules = [ + "api", + "runtime", + "test", + "tutorial", + "streaming", +] + java_import( name = "all_modules", jars = [ - "liborg_ray_ray_api.jar", - "liborg_ray_ray_api-src.jar", - "liborg_ray_ray_runtime.jar", - "liborg_ray_ray_runtime-src.jar", - "liborg_ray_ray_tutorial.jar", - "liborg_ray_ray_tutorial-src.jar", - "liborg_ray_ray_streaming.jar", - "liborg_ray_ray_streaming-src.jar", + "liborg_ray_ray_" + module + ".jar" for module in all_modules + ] + [ + "liborg_ray_ray_" + module + "-src.jar" for module in all_modules + ] + [ "all_tests_deploy.jar", "all_tests_deploy-src.jar", "streaming_tests_deploy.jar", "streaming_tests_deploy-src.jar", ], deps = [ - ":org_ray_ray_api", - ":org_ray_ray_runtime", - ":org_ray_ray_tutorial", - ":org_ray_ray_streaming", + ":org_ray_ray_" + module for module in all_modules + ] + [ ":all_tests", ":streaming_tests", ], @@ -154,6 +156,7 @@ flatbuffers_generated_files = [ "ConfigTableData.java", "CustomSerializerData.java", "DriverTableData.java", + "EntryType.java", "ErrorTableData.java", "ErrorType.java", "FunctionTableData.java", @@ -246,30 +249,10 @@ genrule( local = 1, ) -# generate pom.xml file for maven compile -gen_java_pom_file( - name = "api", -) - -gen_java_pom_file( - name = "runtime", -) - -gen_java_pom_file( - name = "tutorial", -) - -gen_java_pom_file( - name = "test", -) - genrule( name = "copy_pom_file", srcs = [ - "//java:org_ray_ray_api_pom", - "//java:org_ray_ray_runtime_pom", - "//java:org_ray_ray_tutorial_pom", - "//java:org_ray_ray_test_pom", + "//java:org_ray_ray_" + module + "_pom" for module in all_modules ], outs = ["copy_pom_file.out"], cmd = """ @@ -279,6 +262,7 @@ genrule( cp -f $(location //java:org_ray_ray_runtime_pom) $$WORK_DIR/java/runtime/pom.xml cp -f $(location //java:org_ray_ray_tutorial_pom) $$WORK_DIR/java/tutorial/pom.xml cp -f $(location //java:org_ray_ray_test_pom) $$WORK_DIR/java/test/pom.xml + cp -f $(location //java:org_ray_ray_streaming_pom) $$WORK_DIR/java/streaming/pom.xml echo $$(date) > $@ """, local = 1, diff --git a/java/api/pom.xml b/java/api/pom.xml index c7a910cd989f..792e54f6c433 100644 --- a/java/api/pom.xml +++ b/java/api/pom.xml @@ -1,4 +1,5 @@ + @@ -16,21 +17,30 @@ jar - - org.slf4j - slf4j-log4j12 - - - javax.xml.bind - jaxb-api - - - com.sun.xml.bind - jaxb-core - - - com.sun.xml.bind - jaxb-impl - + + com.sun.xml.bind + jaxb-core + 2.3.0 + + + com.sun.xml.bind + jaxb-impl + 2.3.0 + + + javax.xml.bind + jaxb-api + 2.3.0 + + + log4j + log4j + 1.2.17 + + + org.slf4j + slf4j-log4j12 + 1.7.25 + diff --git a/java/api/pom_template.xml b/java/api/pom_template.xml index ae37175a812a..67088f9584cb 100644 --- a/java/api/pom_template.xml +++ b/java/api/pom_template.xml @@ -1,4 +1,5 @@ +{auto_gen_header} @@ -16,6 +17,6 @@ jar - {generated_bzl_deps} +{generated_bzl_deps} diff --git a/java/api/src/main/java/org/ray/api/Ray.java b/java/api/src/main/java/org/ray/api/Ray.java index fa82ea685706..cdad95e16758 100644 --- 
a/java/api/src/main/java/org/ray/api/Ray.java +++ b/java/api/src/main/java/org/ray/api/Ray.java @@ -1,6 +1,7 @@ package org.ray.api; import java.util.List; +import org.ray.api.id.ObjectId; import org.ray.api.id.UniqueId; import org.ray.api.runtime.RayRuntime; import org.ray.api.runtime.RayRuntimeFactory; @@ -65,7 +66,7 @@ public static RayObject put(T obj) { * @param objectId The ID of the object to get. * @return The Java object. */ - public static T get(UniqueId objectId) { + public static T get(ObjectId objectId) { return runtime.get(objectId); } @@ -75,7 +76,7 @@ public static T get(UniqueId objectId) { * @param objectIds The list of object IDs. * @return A list of Java objects. */ - public static List get(List objectIds) { + public static List get(List objectIds) { return runtime.get(objectIds); } @@ -123,6 +124,21 @@ public static RayRuntime internal() { return runtime; } + /** + * Update the resource for the specified client. + * Set the resource for the specific node. + */ + public static void setResource(UniqueId nodeId, String resourceName, double capacity) { + runtime.setResource(resourceName, capacity, nodeId); + } + + /** + * Set the resource for local node. + */ + public static void setResource(String resourceName, double capacity) { + runtime.setResource(resourceName, capacity, UniqueId.NIL); + } + /** * Get the runtime context. */ diff --git a/java/api/src/main/java/org/ray/api/RayObject.java b/java/api/src/main/java/org/ray/api/RayObject.java index a1971be40773..faf42f826aa1 100644 --- a/java/api/src/main/java/org/ray/api/RayObject.java +++ b/java/api/src/main/java/org/ray/api/RayObject.java @@ -1,6 +1,6 @@ package org.ray.api; -import org.ray.api.id.UniqueId; +import org.ray.api.id.ObjectId; /** * Represents an object in the object store. @@ -17,7 +17,7 @@ public interface RayObject { /** * Get the object id. 
*/ - UniqueId getId(); + ObjectId getId(); } diff --git a/java/api/src/main/java/org/ray/api/exception/UnreconstructableException.java b/java/api/src/main/java/org/ray/api/exception/UnreconstructableException.java index 8362295baf1a..0eb2ed9e7dca 100644 --- a/java/api/src/main/java/org/ray/api/exception/UnreconstructableException.java +++ b/java/api/src/main/java/org/ray/api/exception/UnreconstructableException.java @@ -1,6 +1,6 @@ package org.ray.api.exception; -import org.ray.api.id.UniqueId; +import org.ray.api.id.ObjectId; /** * Indicates that an object is lost (either evicted or explicitly deleted) and cannot be @@ -11,9 +11,9 @@ */ public class UnreconstructableException extends RayException { - public final UniqueId objectId; + public final ObjectId objectId; - public UnreconstructableException(UniqueId objectId) { + public UnreconstructableException(ObjectId objectId) { super(String.format( "Object %s is lost (either evicted or explicitly deleted) and cannot be reconstructed.", objectId)); diff --git a/java/api/src/main/java/org/ray/api/id/BaseId.java b/java/api/src/main/java/org/ray/api/id/BaseId.java new file mode 100644 index 000000000000..3c5e1e3a3619 --- /dev/null +++ b/java/api/src/main/java/org/ray/api/id/BaseId.java @@ -0,0 +1,99 @@ +package org.ray.api.id; + +import java.io.Serializable; +import java.nio.ByteBuffer; +import java.util.Arrays; +import javax.xml.bind.DatatypeConverter; + +public abstract class BaseId implements Serializable { + private static final long serialVersionUID = 8588849129675565761L; + private final byte[] id; + private int hashCodeCache = 0; + private Boolean isNilCache = null; + + /** + * Create a BaseId instance according to the input byte array. + */ + public BaseId(byte[] id) { + if (id.length != size()) { + throw new IllegalArgumentException("Failed to construct BaseId, expect " + size() + + " bytes, but got " + id.length + " bytes."); + } + this.id = id; + } + + /** + * Get the byte data of this id. + */ + public byte[] getBytes() { + return id; + } + + /** + * Convert the byte data to a ByteBuffer. + */ + public ByteBuffer toByteBuffer() { + return ByteBuffer.wrap(id); + } + + /** + * @return True if this id is nil. + */ + public boolean isNil() { + if (isNilCache == null) { + isNilCache = true; + for (int i = 0; i < size(); ++i) { + if (id[i] != (byte) 0xff) { + isNilCache = false; + break; + } + } + } + return isNilCache; + } + + /** + * Derived class should implement this function. + * @return The length of this id in bytes. + */ + public abstract int size(); + + @Override + public int hashCode() { + // Lazy evaluation. 
+ if (hashCodeCache == 0) { + hashCodeCache = Arrays.hashCode(id); + } + return hashCodeCache; + } + + @Override + public boolean equals(Object obj) { + if (obj == null) { + return false; + } + + if (!this.getClass().equals(obj.getClass())) { + return false; + } + + BaseId r = (BaseId) obj; + return Arrays.equals(id, r.id); + } + + @Override + public String toString() { + return DatatypeConverter.printHexBinary(id).toLowerCase(); + } + + protected static byte[] hexString2Bytes(String hex) { + return DatatypeConverter.parseHexBinary(hex); + } + + protected static byte[] byteBuffer2Bytes(ByteBuffer bb) { + byte[] id = new byte[bb.remaining()]; + bb.get(id); + return id; + } + +} diff --git a/java/api/src/main/java/org/ray/api/id/ObjectId.java b/java/api/src/main/java/org/ray/api/id/ObjectId.java new file mode 100644 index 000000000000..49c0f39ebe5b --- /dev/null +++ b/java/api/src/main/java/org/ray/api/id/ObjectId.java @@ -0,0 +1,62 @@ +package org.ray.api.id; + +import java.io.Serializable; +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.Random; + +/** + * Represents the id of a Ray object. + */ +public class ObjectId extends BaseId implements Serializable { + + public static final int LENGTH = 20; + public static final ObjectId NIL = genNil(); + + /** + * Create an ObjectId from a hex string. + */ + public static ObjectId fromHexString(String hex) { + return new ObjectId(hexString2Bytes(hex)); + } + + /** + * Create an ObjectId from a ByteBuffer. + */ + public static ObjectId fromByteBuffer(ByteBuffer bb) { + return new ObjectId(byteBuffer2Bytes(bb)); + } + + /** + * Generate a nil ObjectId. + */ + private static ObjectId genNil() { + byte[] b = new byte[LENGTH]; + Arrays.fill(b, (byte) 0xFF); + return new ObjectId(b); + } + + /** + * Generate an ObjectId with random value. + */ + public static ObjectId randomId() { + byte[] b = new byte[LENGTH]; + new Random().nextBytes(b); + return new ObjectId(b); + } + + public ObjectId(byte[] id) { + super(id); + } + + @Override + public int size() { + return LENGTH; + } + + public TaskId getTaskId() { + byte[] taskIdBytes = Arrays.copyOf(getBytes(), TaskId.LENGTH); + return new TaskId(taskIdBytes); + } + +} diff --git a/java/api/src/main/java/org/ray/api/id/TaskId.java b/java/api/src/main/java/org/ray/api/id/TaskId.java new file mode 100644 index 000000000000..8f1fe0694ea4 --- /dev/null +++ b/java/api/src/main/java/org/ray/api/id/TaskId.java @@ -0,0 +1,56 @@ +package org.ray.api.id; + +import java.io.Serializable; +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.Random; + +/** + * Represents the id of a Ray task. + */ +public class TaskId extends BaseId implements Serializable { + + public static final int LENGTH = 16; + public static final TaskId NIL = genNil(); + + /** + * Create a TaskId from a hex string. + */ + public static TaskId fromHexString(String hex) { + return new TaskId(hexString2Bytes(hex)); + } + + /** + * Creates a TaskId from a ByteBuffer. + */ + public static TaskId fromByteBuffer(ByteBuffer bb) { + return new TaskId(byteBuffer2Bytes(bb)); + } + + /** + * Generate a nil TaskId. + */ + private static TaskId genNil() { + byte[] b = new byte[LENGTH]; + Arrays.fill(b, (byte) 0xFF); + return new TaskId(b); + } + + /** + * Generate an TaskId with random value. 
+ */ + public static TaskId randomId() { + byte[] b = new byte[LENGTH]; + new Random().nextBytes(b); + return new TaskId(b); + } + + public TaskId(byte[] id) { + super(id); + } + + @Override + public int size() { + return LENGTH; + } +} diff --git a/java/api/src/main/java/org/ray/api/id/UniqueId.java b/java/api/src/main/java/org/ray/api/id/UniqueId.java index f93bdc737229..4fd723ff26bf 100644 --- a/java/api/src/main/java/org/ray/api/id/UniqueId.java +++ b/java/api/src/main/java/org/ray/api/id/UniqueId.java @@ -4,41 +4,34 @@ import java.nio.ByteBuffer; import java.util.Arrays; import java.util.Random; -import javax.xml.bind.DatatypeConverter; /** * Represents a unique id of all Ray concepts, including - * objects, tasks, workers, actors, etc. + * workers, actors, checkpoints, etc. */ -public class UniqueId implements Serializable { +public class UniqueId extends BaseId implements Serializable { public static final int LENGTH = 20; public static final UniqueId NIL = genNil(); - private static final long serialVersionUID = 8588849129675565761L; - private final byte[] id; /** * Create a UniqueId from a hex string. */ public static UniqueId fromHexString(String hex) { - byte[] bytes = DatatypeConverter.parseHexBinary(hex); - return new UniqueId(bytes); + return new UniqueId(hexString2Bytes(hex)); } /** * Creates a UniqueId from a ByteBuffer. */ public static UniqueId fromByteBuffer(ByteBuffer bb) { - byte[] id = new byte[bb.remaining()]; - bb.get(id); - - return new UniqueId(id); + return new UniqueId(byteBuffer2Bytes(bb)); } /** * Generate a nil UniqueId. */ - public static UniqueId genNil() { + private static UniqueId genNil() { byte[] b = new byte[LENGTH]; Arrays.fill(b, (byte) 0xFF); return new UniqueId(b); @@ -54,64 +47,11 @@ public static UniqueId randomId() { } public UniqueId(byte[] id) { - if (id.length != LENGTH) { - throw new IllegalArgumentException("Illegal argument for UniqueId, expect " + LENGTH - + " bytes, but got " + id.length + " bytes."); - } - - this.id = id; - } - - /** - * Get the byte data of this UniqueId. - */ - public byte[] getBytes() { - return id; - } - - /** - * Convert the byte data to a ByteBuffer. - */ - public ByteBuffer toByteBuffer() { - return ByteBuffer.wrap(id); - } - - /** - * Create a copy of this UniqueId. - */ - public UniqueId copy() { - byte[] nid = Arrays.copyOf(id, id.length); - return new UniqueId(nid); - } - - /** - * Returns true if this id is nil. 
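[Editor's note] The new id classes above split what used to be a single ``UniqueId`` into ``ObjectId`` (20 bytes) and ``TaskId`` (16 bytes), with the task id recoverable as a prefix of the object id. Below is a minimal sketch of that relationship, based only on the classes shown in this diff; the ``IdLayoutSketch`` wrapper class is illustrative and not part of the patch.

.. code-block:: java

    import java.util.Arrays;
    import org.ray.api.id.ObjectId;
    import org.ray.api.id.TaskId;

    public class IdLayoutSketch {
      public static void main(String[] args) {
        // An ObjectId is 20 bytes and a TaskId is 16 bytes; getTaskId()
        // simply takes the first TaskId.LENGTH bytes of the object id.
        ObjectId objectId = ObjectId.randomId();
        TaskId taskId = objectId.getTaskId();

        System.out.println(objectId.size());  // 20 (ObjectId.LENGTH)
        System.out.println(taskId.size());    // 16 (TaskId.LENGTH)
        System.out.println(Arrays.equals(
            taskId.getBytes(),
            Arrays.copyOf(objectId.getBytes(), TaskId.LENGTH)));  // true
      }
    }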
- */ - public boolean isNil() { - return this.equals(NIL); - } - - @Override - public int hashCode() { - return Arrays.hashCode(id); - } - - @Override - public boolean equals(Object obj) { - if (obj == null) { - return false; - } - - if (!(obj instanceof UniqueId)) { - return false; - } - - UniqueId r = (UniqueId) obj; - return Arrays.equals(id, r.id); + super(id); } @Override - public String toString() { - return DatatypeConverter.printHexBinary(id).toLowerCase(); + public int size() { + return LENGTH; } } diff --git a/java/api/src/main/java/org/ray/api/runtime/RayRuntime.java b/java/api/src/main/java/org/ray/api/runtime/RayRuntime.java index 521032316366..5a29c9a39dd1 100644 --- a/java/api/src/main/java/org/ray/api/runtime/RayRuntime.java +++ b/java/api/src/main/java/org/ray/api/runtime/RayRuntime.java @@ -6,6 +6,7 @@ import org.ray.api.RayPyActor; import org.ray.api.WaitResult; import org.ray.api.function.RayFunc; +import org.ray.api.id.ObjectId; import org.ray.api.id.UniqueId; import org.ray.api.options.ActorCreationOptions; import org.ray.api.options.CallOptions; @@ -35,7 +36,7 @@ public interface RayRuntime { * @param objectId The ID of the object to get. * @return The Java object. */ - T get(UniqueId objectId); + T get(ObjectId objectId); /** * Get a list of objects from the object store. @@ -43,7 +44,7 @@ public interface RayRuntime { * @param objectIds The list of object IDs. * @return A list of Java objects. */ - List get(List objectIds); + List get(List objectIds); /** * Wait for a list of RayObjects to be locally available, until specified number of objects are @@ -63,7 +64,16 @@ public interface RayRuntime { * @param localOnly Whether only free objects for local object store or not. * @param deleteCreatingTasks Whether also delete objects' creating tasks from GCS. */ - void free(List objectIds, boolean localOnly, boolean deleteCreatingTasks); + void free(List objectIds, boolean localOnly, boolean deleteCreatingTasks); + + /** + * Set the resource for the specific node. + * + * @param resourceName The name of resource. + * @param capacity The capacity of the resource. + * @param nodeId The node that we want to set its resource. + */ + void setResource(String resourceName, double capacity, UniqueId nodeId); /** * Invoke a remote function. 
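[Editor's note] With the ``Ray`` facade and ``RayRuntime`` changes above, user code now addresses objects by ``ObjectId`` and can adjust node resources at runtime via ``setResource``. A minimal usage sketch under those assumptions follows; the wrapper class and the ``"CustomResource"`` name are illustrative only, and the generic signatures are inferred from the API (type parameters are stripped in this view of the diff).

.. code-block:: java

    import java.util.Arrays;
    import java.util.List;
    import org.ray.api.Ray;
    import org.ray.api.RayObject;
    import org.ray.api.id.ObjectId;

    public class ObjectIdApiSketch {
      public static void main(String[] args) {
        Ray.init();

        // put() still returns a RayObject, but its id is now an ObjectId
        // rather than a UniqueId.
        RayObject<String> obj = Ray.put("hello");
        ObjectId id = obj.getId();

        // get() accepts the new ObjectId type, for a single object or a batch.
        String single = Ray.get(id);
        List<String> batch = Ray.get(Arrays.asList(id, Ray.put("world").getId()));

        // The new dynamic-resource API: update a resource on the local node,
        // or on a specific node when a node id is supplied.
        Ray.setResource("CustomResource", 4.0);
        // Ray.setResource(someNodeId, "CustomResource", 4.0);  // per-node variant

        System.out.println(single + " " + batch);
      }
    }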
diff --git a/java/pom.xml b/java/pom.xml index ce5ffa2faa29..bf7a41229b9b 100644 --- a/java/pom.xml +++ b/java/pom.xml @@ -20,8 +20,6 @@ 1.8 UTF-8 0.1-SNAPSHOT - 1.7.25 - 2.3.0 @@ -31,76 +29,6 @@ arrow-plasma 0.13.0-SNAPSHOT - - de.ruedigermoeller - fst - 2.47 - - - org.ow2.asm - asm - 6.0 - - - com.github.davidmoten - flatbuffers-java - 1.9.0.1 - - - com.beust - jcommander - 1.72 - - - redis.clients - jedis - 2.8.0 - - - commons-io - commons-io - 2.5 - - - org.apache.commons - commons-lang3 - 3.4 - - - com.google.guava - guava - 19.0 - - - org.slf4j - slf4j-log4j12 - ${slf4j.version} - - - com.typesafe - config - 1.3.2 - - - org.testng - testng - 6.9.9 - - - javax.xml.bind - jaxb-api - ${jaxb.version} - - - com.sun.xml.bind - jaxb-core - ${jaxb.version} - - - com.sun.xml.bind - jaxb-impl - ${jaxb.version} - diff --git a/java/runtime/pom.xml b/java/runtime/pom.xml index 4b2cc7d50373..1ce51971c03e 100644 --- a/java/runtime/pom.xml +++ b/java/runtime/pom.xml @@ -1,4 +1,5 @@ + @@ -21,53 +22,70 @@ ray-api ${project.version} - - com.typesafe - config - - - org.apache.commons - commons-lang3 - - - de.ruedigermoeller - fst - - - com.github.davidmoten - flatbuffers-java - - - redis.clients - jedis - org.apache.arrow arrow-plasma - - commons-io - commons-io - - - com.google.guava - guava - - - org.slf4j - slf4j-log4j12 - - - org.ow2.asm - asm - - - - - org.testng - testng - test - + + com.beust + jcommander + 1.72 + + + com.github.davidmoten + flatbuffers-java + 1.9.0.1 + + + com.google.guava + guava + 27.0.1-jre + + + com.typesafe + config + 1.3.2 + + + commons-io + commons-io + 2.5 + + + de.ruedigermoeller + fst + 2.47 + + + org.apache.commons + commons-lang3 + 3.4 + + + org.ow2.asm + asm + 6.0 + + + org.slf4j + slf4j-api + 1.7.25 + + + org.slf4j + slf4j-log4j12 + 1.7.25 + + + org.testng + testng + 6.9.9 + + + redis.clients + jedis + 2.8.0 + diff --git a/java/runtime/pom_template.xml b/java/runtime/pom_template.xml index fc75efe70398..9200bd6c6003 100644 --- a/java/runtime/pom_template.xml +++ b/java/runtime/pom_template.xml @@ -1,4 +1,5 @@ +{auto_gen_header} @@ -25,7 +26,7 @@ org.apache.arrow arrow-plasma - {generated_bzl_deps} +{generated_bzl_deps} diff --git a/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java b/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java index af8cff9d79d9..fbd03bf10483 100644 --- a/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java +++ b/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java @@ -15,6 +15,8 @@ import org.ray.api.WaitResult; import org.ray.api.exception.RayException; import org.ray.api.function.RayFunc; +import org.ray.api.id.ObjectId; +import org.ray.api.id.TaskId; import org.ray.api.id.UniqueId; import org.ray.api.options.ActorCreationOptions; import org.ray.api.options.BaseTaskOptions; @@ -32,7 +34,7 @@ import org.ray.runtime.task.ArgumentsBuilder; import org.ray.runtime.task.TaskLanguage; import org.ray.runtime.task.TaskSpec; -import org.ray.runtime.util.UniqueIdUtil; +import org.ray.runtime.util.IdUtil; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -88,15 +90,15 @@ public AbstractRayRuntime(RayConfig rayConfig) { @Override public RayObject put(T obj) { - UniqueId objectId = UniqueIdUtil.computePutId( + ObjectId objectId = IdUtil.computePutId( workerContext.getCurrentTaskId(), workerContext.nextPutIndex()); put(objectId, obj); return new RayObjectImpl<>(objectId); } - public void put(UniqueId objectId, T obj) { - UniqueId taskId = workerContext.getCurrentTaskId(); + public 
void put(ObjectId objectId, T obj) { + TaskId taskId = workerContext.getCurrentTaskId(); LOGGER.debug("Putting object {}, for task {} ", objectId, taskId); objectStoreProxy.put(objectId, obj); } @@ -109,28 +111,28 @@ public void put(UniqueId objectId, T obj) { * @return A RayObject instance that represents the in-store object. */ public RayObject putSerialized(byte[] obj) { - UniqueId objectId = UniqueIdUtil.computePutId( + ObjectId objectId = IdUtil.computePutId( workerContext.getCurrentTaskId(), workerContext.nextPutIndex()); - UniqueId taskId = workerContext.getCurrentTaskId(); + TaskId taskId = workerContext.getCurrentTaskId(); LOGGER.debug("Putting serialized object {}, for task {} ", objectId, taskId); objectStoreProxy.putSerialized(objectId, obj); return new RayObjectImpl<>(objectId); } @Override - public T get(UniqueId objectId) throws RayException { + public T get(ObjectId objectId) throws RayException { List ret = get(ImmutableList.of(objectId)); return ret.get(0); } @Override - public List get(List objectIds) { + public List get(List objectIds) { List ret = new ArrayList<>(Collections.nCopies(objectIds.size(), null)); boolean wasBlocked = false; try { // A map that stores the unready object ids and their original indexes. - Map unready = new HashMap<>(); + Map unready = new HashMap<>(); for (int i = 0; i < objectIds.size(); i++) { unready.put(objectIds.get(i), i); } @@ -138,7 +140,7 @@ public List get(List objectIds) { // Repeat until we get all objects. while (!unready.isEmpty()) { - List unreadyIds = new ArrayList<>(unready.keySet()); + List unreadyIds = new ArrayList<>(unready.keySet()); // For the initial fetch, we only fetch the objects, do not reconstruct them. boolean fetchOnly = numAttempts == 0; @@ -147,7 +149,7 @@ public List get(List objectIds) { wasBlocked = true; } // Call `fetchOrReconstruct` in batches. - for (List batch : splitIntoBatches(unreadyIds)) { + for (List batch : splitIntoBatches(unreadyIds)) { rayletClient.fetchOrReconstruct(batch, fetchOnly, workerContext.getCurrentTaskId()); } @@ -161,7 +163,7 @@ public List get(List objectIds) { throw getResult.exception; } else { // Set the result to the return list, and remove it from the unready map. - UniqueId id = unreadyIds.get(i); + ObjectId id = unreadyIds.get(i); ret.set(unready.get(id), getResult.object); unready.remove(id); } @@ -172,11 +174,11 @@ public List get(List objectIds) { if (LOGGER.isWarnEnabled() && numAttempts % WARN_PER_NUM_ATTEMPTS == 0) { // Print a warning if we've attempted too many times, but some objects are still // unavailable. 
- List idsToPrint = new ArrayList<>(unready.keySet()); + List idsToPrint = new ArrayList<>(unready.keySet()); if (idsToPrint.size() > MAX_IDS_TO_PRINT_IN_WARNING) { idsToPrint = idsToPrint.subList(0, MAX_IDS_TO_PRINT_IN_WARNING); } - String ids = idsToPrint.stream().map(UniqueId::toString) + String ids = idsToPrint.stream().map(ObjectId::toString) .collect(Collectors.joining(", ")); if (idsToPrint.size() < unready.size()) { ids += ", etc"; @@ -206,17 +208,26 @@ public List get(List objectIds) { } @Override - public void free(List objectIds, boolean localOnly, boolean deleteCreatingTasks) { + public void free(List objectIds, boolean localOnly, boolean deleteCreatingTasks) { rayletClient.freePlasmaObjects(objectIds, localOnly, deleteCreatingTasks); } - private List> splitIntoBatches(List objectIds) { - List> batches = new ArrayList<>(); + @Override + public void setResource(String resourceName, double capacity, UniqueId nodeId) { + Preconditions.checkArgument(Double.compare(capacity, 0) >= 0); + if (nodeId == null) { + nodeId = UniqueId.NIL; + } + rayletClient.setResource(resourceName, capacity, nodeId); + } + + private List> splitIntoBatches(List objectIds) { + List> batches = new ArrayList<>(); int objectsSize = objectIds.size(); for (int i = 0; i < objectsSize; i += FETCH_BATCH_SIZE) { int endIndex = i + FETCH_BATCH_SIZE; - List batchIds = (endIndex < objectsSize) + List batchIds = (endIndex < objectsSize) ? objectIds.subList(i, endIndex) : objectIds.subList(i, objectsSize); @@ -262,7 +273,7 @@ public RayActor createActor(RayFunc actorFactoryFunc, Object[] args, ActorCreationOptions options) { TaskSpec spec = createTaskSpec(actorFactoryFunc, null, RayActorImpl.NIL, args, true, options); - RayActorImpl actor = new RayActorImpl(spec.returnIds[0]); + RayActorImpl actor = new RayActorImpl(new UniqueId(spec.returnIds[0].getBytes())); actor.increaseTaskCounter(); actor.setTaskCursor(spec.returnIds[0]); rayletClient.submitTask(spec); @@ -334,14 +345,14 @@ private TaskSpec createTaskSpec(RayFunc func, PyFunctionDescriptor pyFunctionDes boolean isActorCreationTask, BaseTaskOptions taskOptions) { Preconditions.checkArgument((func == null) != (pyFunctionDescriptor == null)); - UniqueId taskId = rayletClient.generateTaskId(workerContext.getCurrentDriverId(), + TaskId taskId = rayletClient.generateTaskId(workerContext.getCurrentDriverId(), workerContext.getCurrentTaskId(), workerContext.nextTaskIndex()); int numReturns = actor.getId().isNil() ? 
1 : 2; - UniqueId[] returnIds = UniqueIdUtil.genReturnIds(taskId, numReturns); + ObjectId[] returnIds = IdUtil.genReturnIds(taskId, numReturns); UniqueId actorCreationId = UniqueId.NIL; if (isActorCreationTask) { - actorCreationId = returnIds[0]; + actorCreationId = new UniqueId(returnIds[0].getBytes()); } Map resources; @@ -379,7 +390,7 @@ private TaskSpec createTaskSpec(RayFunc func, PyFunctionDescriptor pyFunctionDes actor.increaseTaskCounter(), actor.getNewActorHandles().toArray(new UniqueId[0]), ArgumentsBuilder.wrap(args, language == TaskLanguage.PYTHON), - returnIds, + numReturns, resources, language, functionDescriptor diff --git a/java/runtime/src/main/java/org/ray/runtime/RayActorImpl.java b/java/runtime/src/main/java/org/ray/runtime/RayActorImpl.java index 7899869aef42..c5a9703c9164 100644 --- a/java/runtime/src/main/java/org/ray/runtime/RayActorImpl.java +++ b/java/runtime/src/main/java/org/ray/runtime/RayActorImpl.java @@ -7,6 +7,7 @@ import java.util.ArrayList; import java.util.List; import org.ray.api.RayActor; +import org.ray.api.id.ObjectId; import org.ray.api.id.UniqueId; import org.ray.runtime.util.Sha1Digestor; @@ -30,7 +31,7 @@ public class RayActorImpl implements RayActor, Externalizable { * The unique id of the last return of the last task. * It's used as a dependency for the next task. */ - protected UniqueId taskCursor; + protected ObjectId taskCursor; /** * The number of times that this actor handle has been forked. * It's used to make sure ids of actor handles are unique. @@ -72,7 +73,7 @@ public UniqueId getHandleId() { return handleId; } - public void setTaskCursor(UniqueId taskCursor) { + public void setTaskCursor(ObjectId taskCursor) { this.taskCursor = taskCursor; } @@ -84,7 +85,7 @@ public void clearNewActorHandles() { this.newActorHandles.clear(); } - public UniqueId getTaskCursor() { + public ObjectId getTaskCursor() { return taskCursor; } @@ -121,7 +122,7 @@ public void writeExternal(ObjectOutput out) throws IOException { public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { this.id = (UniqueId) in.readObject(); this.handleId = (UniqueId) in.readObject(); - this.taskCursor = (UniqueId) in.readObject(); + this.taskCursor = (ObjectId) in.readObject(); this.taskCounter = (int) in.readObject(); this.numForks = (int) in.readObject(); } diff --git a/java/runtime/src/main/java/org/ray/runtime/RayObjectImpl.java b/java/runtime/src/main/java/org/ray/runtime/RayObjectImpl.java index 1516543a1e2a..9f8e567f8e09 100644 --- a/java/runtime/src/main/java/org/ray/runtime/RayObjectImpl.java +++ b/java/runtime/src/main/java/org/ray/runtime/RayObjectImpl.java @@ -3,13 +3,13 @@ import java.io.Serializable; import org.ray.api.Ray; import org.ray.api.RayObject; -import org.ray.api.id.UniqueId; +import org.ray.api.id.ObjectId; public final class RayObjectImpl implements RayObject, Serializable { - private final UniqueId id; + private final ObjectId id; - public RayObjectImpl(UniqueId id) { + public RayObjectImpl(ObjectId id) { this.id = id; } @@ -19,7 +19,7 @@ public T get() { } @Override - public UniqueId getId() { + public ObjectId getId() { return id; } diff --git a/java/runtime/src/main/java/org/ray/runtime/Worker.java b/java/runtime/src/main/java/org/ray/runtime/Worker.java index 813a62fdc07e..b4de226e2914 100644 --- a/java/runtime/src/main/java/org/ray/runtime/Worker.java +++ b/java/runtime/src/main/java/org/ray/runtime/Worker.java @@ -7,6 +7,7 @@ import org.ray.api.Checkpointable.Checkpoint; import 
org.ray.api.Checkpointable.CheckpointContext; import org.ray.api.exception.RayTaskException; +import org.ray.api.id.ObjectId; import org.ray.api.id.UniqueId; import org.ray.runtime.config.RunMode; import org.ray.runtime.functionmanager.RayFunction; @@ -80,7 +81,7 @@ public void loop() { */ public void execute(TaskSpec spec) { LOGGER.debug("Executing task {}", spec); - UniqueId returnId = spec.returnIds[0]; + ObjectId returnId = spec.returnIds[0]; ClassLoader oldLoader = Thread.currentThread().getContextClassLoader(); try { // Get method @@ -91,7 +92,7 @@ public void execute(TaskSpec spec) { Thread.currentThread().setContextClassLoader(rayFunction.classLoader); if (spec.isActorCreationTask()) { - currentActorId = returnId; + currentActorId = new UniqueId(returnId.getBytes()); } // Get local actor object and arguments. @@ -119,7 +120,7 @@ public void execute(TaskSpec spec) { } runtime.put(returnId, result); } else { - maybeLoadCheckpoint(result, returnId); + maybeLoadCheckpoint(result, new UniqueId(returnId.getBytes())); currentActor = result; } LOGGER.debug("Finished executing task {}", spec.taskId); diff --git a/java/runtime/src/main/java/org/ray/runtime/WorkerContext.java b/java/runtime/src/main/java/org/ray/runtime/WorkerContext.java index 57f23cf31b19..44703bf673fd 100644 --- a/java/runtime/src/main/java/org/ray/runtime/WorkerContext.java +++ b/java/runtime/src/main/java/org/ray/runtime/WorkerContext.java @@ -1,6 +1,7 @@ package org.ray.runtime; import com.google.common.base.Preconditions; +import org.ray.api.id.TaskId; import org.ray.api.id.UniqueId; import org.ray.runtime.config.RunMode; import org.ray.runtime.config.WorkerMode; @@ -14,7 +15,7 @@ public class WorkerContext { private UniqueId workerId; - private ThreadLocal currentTaskId; + private ThreadLocal currentTaskId; /** * Number of objects that have been put from current task. @@ -46,17 +47,17 @@ public WorkerContext(WorkerMode workerMode, UniqueId driverId, RunMode runMode) mainThreadId = Thread.currentThread().getId(); taskIndex = ThreadLocal.withInitial(() -> 0); putIndex = ThreadLocal.withInitial(() -> 0); - currentTaskId = ThreadLocal.withInitial(UniqueId::randomId); + currentTaskId = ThreadLocal.withInitial(TaskId::randomId); this.runMode = runMode; currentTask = ThreadLocal.withInitial(() -> null); currentClassLoader = null; if (workerMode == WorkerMode.DRIVER) { workerId = driverId; - currentTaskId.set(UniqueId.randomId()); + currentTaskId.set(TaskId.randomId()); currentDriverId = driverId; } else { workerId = UniqueId.randomId(); - this.currentTaskId.set(UniqueId.NIL); + this.currentTaskId.set(TaskId.NIL); this.currentDriverId = UniqueId.NIL; } } @@ -65,7 +66,7 @@ public WorkerContext(WorkerMode workerMode, UniqueId driverId, RunMode runMode) * @return For the main thread, this method returns the ID of this worker's current running task; * for other threads, this method returns a random ID. 
*/ - public UniqueId getCurrentTaskId() { + public TaskId getCurrentTaskId() { return currentTaskId.get(); } diff --git a/java/runtime/src/main/java/org/ray/runtime/gcs/GcsClient.java b/java/runtime/src/main/java/org/ray/runtime/gcs/GcsClient.java index a627f200a0e6..431b48ded58c 100644 --- a/java/runtime/src/main/java/org/ray/runtime/gcs/GcsClient.java +++ b/java/runtime/src/main/java/org/ray/runtime/gcs/GcsClient.java @@ -9,12 +9,15 @@ import java.util.stream.Collectors; import org.apache.commons.lang3.ArrayUtils; import org.ray.api.Checkpointable.Checkpoint; +import org.ray.api.id.BaseId; +import org.ray.api.id.TaskId; import org.ray.api.id.UniqueId; import org.ray.api.runtimecontext.NodeInfo; import org.ray.runtime.generated.ActorCheckpointIdData; import org.ray.runtime.generated.ClientTableData; +import org.ray.runtime.generated.EntryType; import org.ray.runtime.generated.TablePrefix; -import org.ray.runtime.util.UniqueIdUtil; +import org.ray.runtime.util.IdUtil; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -63,7 +66,7 @@ public List getAllNodeInfo() { ClientTableData data = ClientTableData.getRootAsClientTableData(ByteBuffer.wrap(result)); final UniqueId clientId = UniqueId.fromByteBuffer(data.clientIdAsByteBuffer()); - if (data.isInsertion()) { + if (data.entryType() == EntryType.INSERTION) { //Code path of node insertion. Map resources = new HashMap<>(); // Compute resources. @@ -72,12 +75,24 @@ public List getAllNodeInfo() { for (int i = 0; i < data.resourcesTotalLabelLength(); i++) { resources.put(data.resourcesTotalLabel(i), data.resourcesTotalCapacity(i)); } - NodeInfo nodeInfo = new NodeInfo( clientId, data.nodeManagerAddress(), true, resources); clients.put(clientId, nodeInfo); + } else if (data.entryType() == EntryType.RES_CREATEUPDATE) { + Preconditions.checkState(clients.containsKey(clientId)); + NodeInfo nodeInfo = clients.get(clientId); + for (int i = 0; i < data.resourcesTotalLabelLength(); i++) { + nodeInfo.resources.put(data.resourcesTotalLabel(i), data.resourcesTotalCapacity(i)); + } + } else if (data.entryType() == EntryType.RES_DELETE) { + Preconditions.checkState(clients.containsKey(clientId)); + NodeInfo nodeInfo = clients.get(clientId); + for (int i = 0; i < data.resourcesTotalLabelLength(); i++) { + nodeInfo.resources.remove(data.resourcesTotalLabel(i)); + } } else { // Code path of node deletion. + Preconditions.checkState(data.entryType() == EntryType.DELETION); NodeInfo nodeInfo = new NodeInfo(clientId, clients.get(clientId).nodeAddress, false, clients.get(clientId).resources); clients.put(clientId, nodeInfo); @@ -99,7 +114,7 @@ public boolean actorExists(UniqueId actorId) { /** * Query whether the raylet task exists in Gcs. 
*/ - public boolean rayletTaskExistsInGcs(UniqueId taskId) { + public boolean rayletTaskExistsInGcs(TaskId taskId) { byte[] key = ArrayUtils.addAll(TablePrefix.name(TablePrefix.RAYLET_TASK).getBytes(), taskId.getBytes()); RedisClient client = getShardClient(taskId); @@ -119,7 +134,7 @@ public List getCheckpointsForActor(UniqueId actorId) { if (result != null) { ActorCheckpointIdData data = ActorCheckpointIdData.getRootAsActorCheckpointIdData(ByteBuffer.wrap(result)); - UniqueId[] checkpointIds = UniqueIdUtil.getUniqueIdsFromByteBuffer( + UniqueId[] checkpointIds = IdUtil.getUniqueIdsFromByteBuffer( data.checkpointIdsAsByteBuffer()); for (int i = 0; i < checkpointIds.length; i++) { @@ -130,8 +145,8 @@ public List getCheckpointsForActor(UniqueId actorId) { return checkpoints; } - private RedisClient getShardClient(UniqueId key) { - return shards.get((int) Long.remainderUnsigned(UniqueIdUtil.murmurHashCode(key), + private RedisClient getShardClient(BaseId key) { + return shards.get((int) Long.remainderUnsigned(IdUtil.murmurHashCode(key), shards.size())); } diff --git a/java/runtime/src/main/java/org/ray/runtime/objectstore/MockObjectStore.java b/java/runtime/src/main/java/org/ray/runtime/objectstore/MockObjectStore.java index 4b80d3e4c276..f3d64c8340a2 100644 --- a/java/runtime/src/main/java/org/ray/runtime/objectstore/MockObjectStore.java +++ b/java/runtime/src/main/java/org/ray/runtime/objectstore/MockObjectStore.java @@ -9,7 +9,7 @@ import java.util.stream.Collectors; import org.apache.arrow.plasma.ObjectStoreLink; -import org.ray.api.id.UniqueId; +import org.ray.api.id.ObjectId; import org.ray.runtime.RayDevRuntime; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -24,16 +24,16 @@ public class MockObjectStore implements ObjectStoreLink { private static final int GET_CHECK_INTERVAL_MS = 100; private final RayDevRuntime runtime; - private final Map data = new ConcurrentHashMap<>(); - private final Map metadata = new ConcurrentHashMap<>(); - private final List> objectPutCallbacks; + private final Map data = new ConcurrentHashMap<>(); + private final Map metadata = new ConcurrentHashMap<>(); + private final List> objectPutCallbacks; public MockObjectStore(RayDevRuntime runtime) { this.runtime = runtime; this.objectPutCallbacks = new ArrayList<>(); } - public void addObjectPutCallback(Consumer callback) { + public void addObjectPutCallback(Consumer callback) { this.objectPutCallbacks.add(callback); } @@ -44,13 +44,12 @@ public void put(byte[] objectId, byte[] value, byte[] metadataValue) { .error("{} cannot put null: {}, {}", logPrefix(), objectId, Arrays.toString(value)); System.exit(-1); } - UniqueId uniqueId = new UniqueId(objectId); - data.put(uniqueId, value); + ObjectId id = new ObjectId(objectId); + data.put(id, value); if (metadataValue != null) { - metadata.put(uniqueId, metadataValue); + metadata.put(id, metadataValue); } - UniqueId id = new UniqueId(objectId); - for (Consumer callback : objectPutCallbacks) { + for (Consumer callback : objectPutCallbacks) { callback.accept(id); } } @@ -85,7 +84,7 @@ public List get(byte[][] objectIds, int timeoutMs) { } ready = 0; for (byte[] id : objectIds) { - if (data.containsKey(new UniqueId(id))) { + if (data.containsKey(new ObjectId(id))) { ready += 1; } } @@ -93,8 +92,8 @@ public List get(byte[][] objectIds, int timeoutMs) { } ArrayList rets = new ArrayList<>(); for (byte[] objId : objectIds) { - UniqueId uniqueId = new UniqueId(objId); - rets.add(new ObjectStoreData(metadata.get(uniqueId), data.get(uniqueId))); + ObjectId objectId 
= new ObjectId(objId); + rets.add(new ObjectStoreData(metadata.get(objectId), data.get(objectId))); } return rets; } @@ -121,7 +120,7 @@ public void delete(byte[] objectId) { @Override public boolean contains(byte[] objectId) { - return data.containsKey(new UniqueId(objectId)); + return data.containsKey(new ObjectId(objectId)); } private String logPrefix() { @@ -138,11 +137,11 @@ private String getUserTrace() { return stes[k].getFileName() + ":" + stes[k].getLineNumber(); } - public boolean isObjectReady(UniqueId id) { + public boolean isObjectReady(ObjectId id) { return data.containsKey(id); } - public void free(UniqueId id) { + public void free(ObjectId id) { data.remove(id); metadata.remove(id); } diff --git a/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectStoreProxy.java b/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectStoreProxy.java index 64b9e2b73a9f..f9e310249a35 100644 --- a/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectStoreProxy.java +++ b/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectStoreProxy.java @@ -12,13 +12,13 @@ import org.ray.api.exception.RayException; import org.ray.api.exception.RayWorkerException; import org.ray.api.exception.UnreconstructableException; -import org.ray.api.id.UniqueId; +import org.ray.api.id.ObjectId; import org.ray.runtime.AbstractRayRuntime; import org.ray.runtime.RayDevRuntime; import org.ray.runtime.config.RunMode; import org.ray.runtime.generated.ErrorType; +import org.ray.runtime.util.IdUtil; import org.ray.runtime.util.Serializer; -import org.ray.runtime.util.UniqueIdUtil; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -61,7 +61,7 @@ public ObjectStoreProxy(AbstractRayRuntime runtime, String storeSocketName) { * @param Type of the object. * @return The GetResult object. */ - public GetResult get(UniqueId id, int timeoutMs) { + public GetResult get(ObjectId id, int timeoutMs) { List> list = get(ImmutableList.of(id), timeoutMs); return list.get(0); } @@ -74,8 +74,8 @@ public GetResult get(UniqueId id, int timeoutMs) { * @param Type of these objects. * @return A list of GetResult objects. */ - public List> get(List ids, int timeoutMs) { - byte[][] binaryIds = UniqueIdUtil.getIdBytes(ids); + public List> get(List ids, int timeoutMs) { + byte[][] binaryIds = IdUtil.getIdBytes(ids); List dataAndMetaList = objectStore.get().get(binaryIds, timeoutMs); List> results = new ArrayList<>(); @@ -114,7 +114,7 @@ public List> get(List ids, int timeoutMs) { } @SuppressWarnings("unchecked") - private GetResult deserializeFromMeta(byte[] meta, byte[] data, UniqueId objectId) { + private GetResult deserializeFromMeta(byte[] meta, byte[] data, ObjectId objectId) { if (Arrays.equals(meta, RAW_TYPE_META)) { return (GetResult) new GetResult<>(true, data, null); } else if (Arrays.equals(meta, WORKER_EXCEPTION_META)) { @@ -133,7 +133,7 @@ private GetResult deserializeFromMeta(byte[] meta, byte[] data, UniqueId * @param id Id of the object. * @param object The object to put. */ - public void put(UniqueId id, Object object) { + public void put(ObjectId id, Object object) { try { if (object instanceof byte[]) { // If the object is a byte array, skip serializing it and use a special metadata to @@ -153,7 +153,7 @@ public void put(UniqueId id, Object object) { * @param id Id of the object. * @param serializedObject The serialized object to put. 
*/ - public void putSerialized(UniqueId id, byte[] serializedObject) { + public void putSerialized(ObjectId id, byte[] serializedObject) { try { objectStore.get().put(id.getBytes(), serializedObject, null); } catch (DuplicateObjectException e) { diff --git a/java/runtime/src/main/java/org/ray/runtime/raylet/MockRayletClient.java b/java/runtime/src/main/java/org/ray/runtime/raylet/MockRayletClient.java index 385431c7055f..fe1f61d0bc11 100644 --- a/java/runtime/src/main/java/org/ray/runtime/raylet/MockRayletClient.java +++ b/java/runtime/src/main/java/org/ray/runtime/raylet/MockRayletClient.java @@ -17,6 +17,8 @@ import org.apache.commons.lang3.NotImplementedException; import org.ray.api.RayObject; import org.ray.api.WaitResult; +import org.ray.api.id.ObjectId; +import org.ray.api.id.TaskId; import org.ray.api.id.UniqueId; import org.ray.runtime.RayDevRuntime; import org.ray.runtime.Worker; @@ -33,7 +35,7 @@ public class MockRayletClient implements RayletClient { private static final Logger LOGGER = LoggerFactory.getLogger(MockRayletClient.class); - private final Map> waitingTasks = new ConcurrentHashMap<>(); + private final Map> waitingTasks = new ConcurrentHashMap<>(); private final MockObjectStore store; private final RayDevRuntime runtime; private final ExecutorService exec; @@ -52,7 +54,7 @@ public MockRayletClient(RayDevRuntime runtime, int numberThreads) { currentWorker = new ThreadLocal<>(); } - public synchronized void onObjectPut(UniqueId id) { + public synchronized void onObjectPut(ObjectId id) { Set tasks = waitingTasks.get(id); if (tasks != null) { waitingTasks.remove(id); @@ -98,7 +100,7 @@ private void returnWorker(Worker worker) { @Override public synchronized void submitTask(TaskSpec task) { LOGGER.debug("Submitting task: {}.", task); - Set unreadyObjects = getUnreadyObjects(task); + Set unreadyObjects = getUnreadyObjects(task); if (unreadyObjects.isEmpty()) { // If all dependencies are ready, execute this task. exec.submit(() -> { @@ -109,7 +111,7 @@ public synchronized void submitTask(TaskSpec task) { // put the dummy object in object store, so those tasks which depends on it // can be executed. if (task.isActorCreationTask() || task.isActorTask()) { - UniqueId[] returnIds = task.returnIds; + ObjectId[] returnIds = task.returnIds; store.put(returnIds[returnIds.length - 1].getBytes(), new byte[]{}, new byte[]{}); } @@ -119,14 +121,14 @@ public synchronized void submitTask(TaskSpec task) { }); } else { // If some dependencies aren't ready yet, put this task in waiting list. - for (UniqueId id : unreadyObjects) { + for (ObjectId id : unreadyObjects) { waitingTasks.computeIfAbsent(id, k -> new HashSet<>()).add(task); } } } - private Set getUnreadyObjects(TaskSpec spec) { - Set unreadyObjects = new HashSet<>(); + private Set getUnreadyObjects(TaskSpec spec) { + Set unreadyObjects = new HashSet<>(); // Check whether task arguments are ready. for (FunctionArg arg : spec.args) { if (arg.id != null) { @@ -136,7 +138,7 @@ private Set getUnreadyObjects(TaskSpec spec) { } } // Check whether task dependencies are ready. 
- for (UniqueId id : spec.getExecutionDependencies()) { + for (ObjectId id : spec.getExecutionDependencies()) { if (!store.isObjectReady(id)) { unreadyObjects.add(id); } @@ -151,24 +153,24 @@ public TaskSpec getTask() { } @Override - public void fetchOrReconstruct(List objectIds, boolean fetchOnly, - UniqueId currentTaskId) { + public void fetchOrReconstruct(List objectIds, boolean fetchOnly, + TaskId currentTaskId) { } @Override - public void notifyUnblocked(UniqueId currentTaskId) { + public void notifyUnblocked(TaskId currentTaskId) { } @Override - public UniqueId generateTaskId(UniqueId driverId, UniqueId parentTaskId, int taskIndex) { - return UniqueId.randomId(); + public TaskId generateTaskId(UniqueId driverId, TaskId parentTaskId, int taskIndex) { + return TaskId.randomId(); } @Override public WaitResult wait(List> waitFor, int numReturns, int - timeoutMs, UniqueId currentTaskId) { + timeoutMs, TaskId currentTaskId) { if (waitFor == null || waitFor.isEmpty()) { return new WaitResult<>(ImmutableList.of(), ImmutableList.of()); } @@ -191,9 +193,9 @@ public WaitResult wait(List> waitFor, int numReturns, int } @Override - public void freePlasmaObjects(List objectIds, boolean localOnly, + public void freePlasmaObjects(List objectIds, boolean localOnly, boolean deleteCreatingTasks) { - for (UniqueId id : objectIds) { + for (ObjectId id : objectIds) { store.free(id); } } @@ -209,6 +211,11 @@ public void notifyActorResumedFromCheckpoint(UniqueId actorId, UniqueId checkpoi throw new NotImplementedException("Not implemented."); } + @Override + public void setResource(String resourceName, double capacity, UniqueId nodeId) { + LOGGER.error("Not implemented under SINGLE_PROCESS mode."); + } + @Override public void destroy() { exec.shutdown(); diff --git a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClient.java b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClient.java index fc6fc75b0fbd..4a78fde9430e 100644 --- a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClient.java +++ b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClient.java @@ -3,6 +3,8 @@ import java.util.List; import org.ray.api.RayObject; import org.ray.api.WaitResult; +import org.ray.api.id.ObjectId; +import org.ray.api.id.TaskId; import org.ray.api.id.UniqueId; import org.ray.runtime.task.TaskSpec; @@ -15,20 +17,22 @@ public interface RayletClient { TaskSpec getTask(); - void fetchOrReconstruct(List objectIds, boolean fetchOnly, UniqueId currentTaskId); + void fetchOrReconstruct(List objectIds, boolean fetchOnly, TaskId currentTaskId); - void notifyUnblocked(UniqueId currentTaskId); + void notifyUnblocked(TaskId currentTaskId); - UniqueId generateTaskId(UniqueId driverId, UniqueId parentTaskId, int taskIndex); + TaskId generateTaskId(UniqueId driverId, TaskId parentTaskId, int taskIndex); WaitResult wait(List> waitFor, int numReturns, int - timeoutMs, UniqueId currentTaskId); + timeoutMs, TaskId currentTaskId); - void freePlasmaObjects(List objectIds, boolean localOnly, boolean deleteCreatingTasks); + void freePlasmaObjects(List objectIds, boolean localOnly, boolean deleteCreatingTasks); UniqueId prepareCheckpoint(UniqueId actorId); void notifyActorResumedFromCheckpoint(UniqueId actorId, UniqueId checkpointId); + void setResource(String resourceName, double capacity, UniqueId nodeId); + void destroy(); } diff --git a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java index 
0ed1f9c86fbf..01b9e4675016 100644 --- a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java +++ b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java @@ -11,6 +11,8 @@ import org.ray.api.RayObject; import org.ray.api.WaitResult; import org.ray.api.exception.RayException; +import org.ray.api.id.ObjectId; +import org.ray.api.id.TaskId; import org.ray.api.id.UniqueId; import org.ray.runtime.functionmanager.JavaFunctionDescriptor; import org.ray.runtime.generated.Arg; @@ -20,7 +22,7 @@ import org.ray.runtime.task.FunctionArg; import org.ray.runtime.task.TaskLanguage; import org.ray.runtime.task.TaskSpec; -import org.ray.runtime.util.UniqueIdUtil; +import org.ray.runtime.util.IdUtil; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -50,18 +52,18 @@ public RayletClientImpl(String schedulerSockName, UniqueId clientId, @Override public WaitResult wait(List> waitFor, int numReturns, int - timeoutMs, UniqueId currentTaskId) { + timeoutMs, TaskId currentTaskId) { Preconditions.checkNotNull(waitFor); if (waitFor.isEmpty()) { return new WaitResult<>(new ArrayList<>(), new ArrayList<>()); } - List ids = new ArrayList<>(); + List ids = new ArrayList<>(); for (RayObject element : waitFor) { ids.add(element.getId()); } - boolean[] ready = nativeWaitObject(client, UniqueIdUtil.getIdBytes(ids), + boolean[] ready = nativeWaitObject(client, IdUtil.getIdBytes(ids), numReturns, timeoutMs, false, currentTaskId.getBytes()); List> readyList = new ArrayList<>(); List> unreadyList = new ArrayList<>(); @@ -101,31 +103,31 @@ public TaskSpec getTask() { } @Override - public void fetchOrReconstruct(List objectIds, boolean fetchOnly, - UniqueId currentTaskId) { + public void fetchOrReconstruct(List objectIds, boolean fetchOnly, + TaskId currentTaskId) { if (LOGGER.isDebugEnabled()) { LOGGER.debug("Blocked on objects for task {}, object IDs are {}", - UniqueIdUtil.computeTaskId(objectIds.get(0)), objectIds); + objectIds.get(0).getTaskId(), objectIds); } - nativeFetchOrReconstruct(client, UniqueIdUtil.getIdBytes(objectIds), + nativeFetchOrReconstruct(client, IdUtil.getIdBytes(objectIds), fetchOnly, currentTaskId.getBytes()); } @Override - public UniqueId generateTaskId(UniqueId driverId, UniqueId parentTaskId, int taskIndex) { + public TaskId generateTaskId(UniqueId driverId, TaskId parentTaskId, int taskIndex) { byte[] bytes = nativeGenerateTaskId(driverId.getBytes(), parentTaskId.getBytes(), taskIndex); - return new UniqueId(bytes); + return new TaskId(bytes); } @Override - public void notifyUnblocked(UniqueId currentTaskId) { + public void notifyUnblocked(TaskId currentTaskId) { nativeNotifyUnblocked(client, currentTaskId.getBytes()); } @Override - public void freePlasmaObjects(List objectIds, boolean localOnly, + public void freePlasmaObjects(List objectIds, boolean localOnly, boolean deleteCreatingTasks) { - byte[][] objectIdsArray = UniqueIdUtil.getIdBytes(objectIds); + byte[][] objectIdsArray = IdUtil.getIdBytes(objectIds); nativeFreePlasmaObjects(client, objectIdsArray, localOnly, deleteCreatingTasks); } @@ -144,17 +146,18 @@ private static TaskSpec parseTaskSpecFromFlatbuffer(ByteBuffer bb) { bb.order(ByteOrder.LITTLE_ENDIAN); TaskInfo info = TaskInfo.getRootAsTaskInfo(bb); UniqueId driverId = UniqueId.fromByteBuffer(info.driverIdAsByteBuffer()); - UniqueId taskId = UniqueId.fromByteBuffer(info.taskIdAsByteBuffer()); - UniqueId parentTaskId = UniqueId.fromByteBuffer(info.parentTaskIdAsByteBuffer()); + TaskId taskId = 
TaskId.fromByteBuffer(info.taskIdAsByteBuffer()); + TaskId parentTaskId = TaskId.fromByteBuffer(info.parentTaskIdAsByteBuffer()); int parentCounter = info.parentCounter(); UniqueId actorCreationId = UniqueId.fromByteBuffer(info.actorCreationIdAsByteBuffer()); int maxActorReconstructions = info.maxActorReconstructions(); UniqueId actorId = UniqueId.fromByteBuffer(info.actorIdAsByteBuffer()); UniqueId actorHandleId = UniqueId.fromByteBuffer(info.actorHandleIdAsByteBuffer()); int actorCounter = info.actorCounter(); + int numReturns = info.numReturns(); // Deserialize new actor handles - UniqueId[] newActorHandles = UniqueIdUtil.getUniqueIdsFromByteBuffer( + UniqueId[] newActorHandles = IdUtil.getUniqueIdsFromByteBuffer( info.newActorHandlesAsByteBuffer()); // Deserialize args @@ -166,8 +169,7 @@ private static TaskSpec parseTaskSpecFromFlatbuffer(ByteBuffer bb) { if (objectIdsLength > 0) { Preconditions.checkArgument(objectIdsLength == 1, "This arg has more than one id: {}", objectIdsLength); - UniqueId id = UniqueIdUtil.getUniqueIdsFromByteBuffer(arg.objectIdsAsByteBuffer())[0]; - args[i] = FunctionArg.passByReference(id); + args[i] = FunctionArg.passByReference(ObjectId.fromByteBuffer(arg.objectIdsAsByteBuffer())); } else { ByteBuffer lbb = arg.dataAsByteBuffer(); Preconditions.checkState(lbb != null && lbb.remaining() > 0); @@ -176,8 +178,6 @@ private static TaskSpec parseTaskSpecFromFlatbuffer(ByteBuffer bb) { args[i] = FunctionArg.passByValue(data); } } - // Deserialize return ids - UniqueId[] returnIds = UniqueIdUtil.getUniqueIdsFromByteBuffer(info.returnsAsByteBuffer()); // Deserialize required resources; Map resources = new HashMap<>(); @@ -192,7 +192,7 @@ private static TaskSpec parseTaskSpecFromFlatbuffer(ByteBuffer bb) { ); return new TaskSpec(driverId, taskId, parentTaskId, parentCounter, actorCreationId, maxActorReconstructions, actorId, actorHandleId, actorCounter, newActorHandles, - args, returnIds, resources, TaskLanguage.JAVA, functionDescriptor); + args, numReturns, resources, TaskLanguage.JAVA, functionDescriptor); } private static ByteBuffer convertTaskSpecToFlatbuffer(TaskSpec task) { @@ -210,10 +210,11 @@ private static ByteBuffer convertTaskSpecToFlatbuffer(TaskSpec task) { final int actorIdOffset = fbb.createString(task.actorId.toByteBuffer()); final int actorHandleIdOffset = fbb.createString(task.actorHandleId.toByteBuffer()); final int actorCounter = task.actorCounter; + final int numReturnsOffset = task.numReturns; // Serialize the new actor handles. int newActorHandlesOffset - = fbb.createString(UniqueIdUtil.concatUniqueIds(task.newActorHandles)); + = fbb.createString(IdUtil.concatIds(task.newActorHandles)); // Serialize args int[] argsOffsets = new int[task.args.length]; @@ -222,7 +223,7 @@ private static ByteBuffer convertTaskSpecToFlatbuffer(TaskSpec task) { int dataOffset = 0; if (task.args[i].id != null) { objectIdOffset = fbb.createString( - UniqueIdUtil.concatUniqueIds(new UniqueId[]{task.args[i].id})); + IdUtil.concatIds(new ObjectId[]{task.args[i].id})); } else { objectIdOffset = fbb.createString(""); } @@ -233,9 +234,6 @@ private static ByteBuffer convertTaskSpecToFlatbuffer(TaskSpec task) { } int argsOffset = fbb.createVectorOfTables(argsOffsets); - // Serialize returns - int returnsOffset = fbb.createString(UniqueIdUtil.concatUniqueIds(task.returnIds)); - // Serialize required resources // The required_resources vector indicates the quantities of the different // resources required by this task. 
The index in this vector corresponds to @@ -291,7 +289,7 @@ private static ByteBuffer convertTaskSpecToFlatbuffer(TaskSpec task) { actorCounter, newActorHandlesOffset, argsOffset, - returnsOffset, + numReturnsOffset, requiredResourcesOffset, requiredPlacementResourcesOffset, language, @@ -308,6 +306,10 @@ private static ByteBuffer convertTaskSpecToFlatbuffer(TaskSpec task) { return buffer; } + public void setResource(String resourceName, double capacity, UniqueId nodeId) { + nativeSetResource(client, resourceName, capacity, nodeId.getBytes()); + } + public void destroy() { nativeDestroy(client); } @@ -357,4 +359,7 @@ private static native void nativeFreePlasmaObjects(long conn, byte[][] objectIds private static native void nativeNotifyActorResumedFromCheckpoint(long conn, byte[] actorId, byte[] checkpointId); + + private static native void nativeSetResource(long conn, String resourceName, double capacity, + byte[] nodeId) throws RayException; } diff --git a/java/runtime/src/main/java/org/ray/runtime/runner/worker/DefaultWorker.java b/java/runtime/src/main/java/org/ray/runtime/runner/worker/DefaultWorker.java index 6fd3ea0e76f9..211411906fdc 100644 --- a/java/runtime/src/main/java/org/ray/runtime/runner/worker/DefaultWorker.java +++ b/java/runtime/src/main/java/org/ray/runtime/runner/worker/DefaultWorker.java @@ -15,6 +15,9 @@ public class DefaultWorker { public static void main(String[] args) { try { System.setProperty("ray.worker.mode", "WORKER"); + Thread.setDefaultUncaughtExceptionHandler((Thread t, Throwable e) -> { + LOGGER.error("Uncaught worker exception in thread {}: {}", t, e); + }); Ray.init(); LOGGER.info("Worker started."); ((AbstractRayRuntime)Ray.internal()).loop(); diff --git a/java/runtime/src/main/java/org/ray/runtime/task/ArgumentsBuilder.java b/java/runtime/src/main/java/org/ray/runtime/task/ArgumentsBuilder.java index 1da6dec31eb1..52447cf79334 100644 --- a/java/runtime/src/main/java/org/ray/runtime/task/ArgumentsBuilder.java +++ b/java/runtime/src/main/java/org/ray/runtime/task/ArgumentsBuilder.java @@ -5,7 +5,7 @@ import org.ray.api.Ray; import org.ray.api.RayActor; import org.ray.api.RayObject; -import org.ray.api.id.UniqueId; +import org.ray.api.id.ObjectId; import org.ray.runtime.AbstractRayRuntime; import org.ray.runtime.util.Serializer; @@ -24,7 +24,7 @@ public static FunctionArg[] wrap(Object[] args, boolean crossLanguage) { FunctionArg[] ret = new FunctionArg[args.length]; for (int i = 0; i < ret.length; i++) { Object arg = args[i]; - UniqueId id = null; + ObjectId id = null; byte[] data = null; if (arg == null) { data = Serializer.encode(null); @@ -59,7 +59,7 @@ public static FunctionArg[] wrap(Object[] args, boolean crossLanguage) { */ public static Object[] unwrap(TaskSpec task, ClassLoader classLoader) { Object[] realArgs = new Object[task.args.length]; - List idsToFetch = new ArrayList<>(); + List idsToFetch = new ArrayList<>(); List indices = new ArrayList<>(); for (int i = 0; i < task.args.length; i++) { FunctionArg arg = task.args[i]; diff --git a/java/runtime/src/main/java/org/ray/runtime/task/FunctionArg.java b/java/runtime/src/main/java/org/ray/runtime/task/FunctionArg.java index 19a16e872b55..95bdcb0da653 100644 --- a/java/runtime/src/main/java/org/ray/runtime/task/FunctionArg.java +++ b/java/runtime/src/main/java/org/ray/runtime/task/FunctionArg.java @@ -1,6 +1,6 @@ package org.ray.runtime.task; -import org.ray.api.id.UniqueId; +import org.ray.api.id.ObjectId; /** * Represents a function argument in task spec. 
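The ArgumentsBuilder and FunctionArg changes above keep the existing wrap/unwrap pattern while switching per-argument references from UniqueId to ObjectId: each argument travels either as serialized bytes (pass by value) or as an object ID (pass by reference), and unwrap gathers all referenced IDs into a single batched fetch before splicing the values back into their original positions. A minimal sketch of that pattern in Python, assuming hypothetical fetch and deserialize callbacks rather than Ray's actual APIs:

from dataclasses import dataclass
from typing import Callable, List, Optional


@dataclass
class FunctionArg:
    object_id: Optional[bytes] = None  # set when passed by reference
    data: Optional[bytes] = None       # set when passed by value


def unwrap(args: List[FunctionArg],
           fetch: Callable[[List[bytes]], List[object]],
           deserialize: Callable[[bytes], object]) -> List[object]:
    # Resolve a task's arguments, fetching all by-reference ones in one batch.
    real_args: List[object] = [None] * len(args)
    ids_to_fetch, indices = [], []
    for i, arg in enumerate(args):
        if arg.object_id is not None:
            ids_to_fetch.append(arg.object_id)
            indices.append(i)
        else:
            real_args[i] = deserialize(arg.data)
    # Splice the fetched values back into their original argument positions.
    for i, value in zip(indices, fetch(ids_to_fetch)):
        real_args[i] = value
    return real_args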
@@ -12,13 +12,13 @@ public class FunctionArg { /** * The id of this argument (passed by reference). */ - public final UniqueId id; + public final ObjectId id; /** * Serialized data of this argument (passed by value). */ public final byte[] data; - private FunctionArg(UniqueId id, byte[] data) { + private FunctionArg(ObjectId id, byte[] data) { this.id = id; this.data = data; } @@ -26,7 +26,7 @@ private FunctionArg(UniqueId id, byte[] data) { /** * Create a FunctionArg that will be passed by reference. */ - public static FunctionArg passByReference(UniqueId id) { + public static FunctionArg passByReference(ObjectId id) { return new FunctionArg(id, null); } diff --git a/java/runtime/src/main/java/org/ray/runtime/task/TaskSpec.java b/java/runtime/src/main/java/org/ray/runtime/task/TaskSpec.java index d8f715ce6a76..3473a9bdb3cc 100644 --- a/java/runtime/src/main/java/org/ray/runtime/task/TaskSpec.java +++ b/java/runtime/src/main/java/org/ray/runtime/task/TaskSpec.java @@ -5,10 +5,13 @@ import java.util.Arrays; import java.util.List; import java.util.Map; +import org.ray.api.id.ObjectId; +import org.ray.api.id.TaskId; import org.ray.api.id.UniqueId; import org.ray.runtime.functionmanager.FunctionDescriptor; import org.ray.runtime.functionmanager.JavaFunctionDescriptor; import org.ray.runtime.functionmanager.PyFunctionDescriptor; +import org.ray.runtime.util.IdUtil; /** * Represents necessary information of a task for scheduling and executing. @@ -19,10 +22,10 @@ public class TaskSpec { public final UniqueId driverId; // Task ID of the task. - public final UniqueId taskId; + public final TaskId taskId; // Task ID of the parent task. - public final UniqueId parentTaskId; + public final TaskId parentTaskId; // A count of the number of tasks submitted by the parent task before this one. public final int parentCounter; @@ -48,8 +51,11 @@ public class TaskSpec { // Task arguments. public final FunctionArg[] args; - // return ids - public final UniqueId[] returnIds; + // number of return objects. + public final int numReturns; + + // returns ids. + public final ObjectId[] returnIds; // The task's resource demands. public final Map resources; @@ -62,7 +68,7 @@ public class TaskSpec { // is Python, the type is PyFunctionDescriptor. 
private final FunctionDescriptor functionDescriptor; - private List executionDependencies; + private List executionDependencies; public boolean isActorTask() { return !actorId.isNil(); @@ -74,8 +80,8 @@ public boolean isActorCreationTask() { public TaskSpec( UniqueId driverId, - UniqueId taskId, - UniqueId parentTaskId, + TaskId taskId, + TaskId parentTaskId, int parentCounter, UniqueId actorCreationId, int maxActorReconstructions, @@ -84,7 +90,7 @@ public TaskSpec( int actorCounter, UniqueId[] newActorHandles, FunctionArg[] args, - UniqueId[] returnIds, + int numReturns, Map resources, TaskLanguage language, FunctionDescriptor functionDescriptor) { @@ -99,7 +105,11 @@ public TaskSpec( this.actorCounter = actorCounter; this.newActorHandles = newActorHandles; this.args = args; - this.returnIds = returnIds; + this.numReturns = numReturns; + returnIds = new ObjectId[numReturns]; + for (int i = 0; i < numReturns; ++i) { + returnIds[i] = IdUtil.computeReturnId(taskId, i + 1); + } this.resources = resources; this.language = language; if (language == TaskLanguage.JAVA) { @@ -125,7 +135,7 @@ public PyFunctionDescriptor getPyFunctionDescriptor() { return (PyFunctionDescriptor) functionDescriptor; } - public List getExecutionDependencies() { + public List getExecutionDependencies() { return executionDependencies; } @@ -143,7 +153,7 @@ public String toString() { ", actorCounter=" + actorCounter + ", newActorHandles=" + Arrays.toString(newActorHandles) + ", args=" + Arrays.toString(args) + - ", returnIds=" + Arrays.toString(returnIds) + + ", numReturns=" + numReturns + ", resources=" + resources + ", language=" + language + ", functionDescriptor=" + functionDescriptor + diff --git a/java/runtime/src/main/java/org/ray/runtime/util/UniqueIdUtil.java b/java/runtime/src/main/java/org/ray/runtime/util/IdUtil.java similarity index 64% rename from java/runtime/src/main/java/org/ray/runtime/util/UniqueIdUtil.java rename to java/runtime/src/main/java/org/ray/runtime/util/IdUtil.java index fa8b51ffaac8..62c56d17ceed 100644 --- a/java/runtime/src/main/java/org/ray/runtime/util/UniqueIdUtil.java +++ b/java/runtime/src/main/java/org/ray/runtime/util/IdUtil.java @@ -3,19 +3,20 @@ import com.google.common.base.Preconditions; import java.nio.ByteBuffer; import java.nio.ByteOrder; -import java.util.Arrays; import java.util.List; +import org.ray.api.id.BaseId; +import org.ray.api.id.ObjectId; +import org.ray.api.id.TaskId; import org.ray.api.id.UniqueId; /** - * Helper method for UniqueId. + * Helper method for different Ids. * Note: any changes to these methods must be synced with C++ helper functions * in src/ray/id.h */ -public class UniqueIdUtil { - public static final int OBJECT_INDEX_POS = 0; - public static final int OBJECT_INDEX_LENGTH = 4; +public class IdUtil { + public static final int OBJECT_INDEX_POS = 16; /** * Compute the object ID of an object returned by the task. @@ -24,7 +25,7 @@ public class UniqueIdUtil { * @param returnIndex What number return value this object is in the task. * @return The computed object ID. */ - public static UniqueId computeReturnId(UniqueId taskId, int returnIndex) { + public static ObjectId computeReturnId(TaskId taskId, int returnIndex) { return computeObjectId(taskId, returnIndex); } @@ -34,14 +35,13 @@ public static UniqueId computeReturnId(UniqueId taskId, int returnIndex) { * @param index The index which can distinguish different objects in one task. * @return The computed object ID. 
*/ - private static UniqueId computeObjectId(UniqueId taskId, int index) { - byte[] objId = new byte[UniqueId.LENGTH]; - System.arraycopy(taskId.getBytes(),0, objId, 0, UniqueId.LENGTH); - ByteBuffer wbb = ByteBuffer.wrap(objId); + private static ObjectId computeObjectId(TaskId taskId, int index) { + byte[] bytes = new byte[ObjectId.LENGTH]; + System.arraycopy(taskId.getBytes(), 0, bytes, 0, taskId.size()); + ByteBuffer wbb = ByteBuffer.wrap(bytes); wbb.order(ByteOrder.LITTLE_ENDIAN); - wbb.putInt(UniqueIdUtil.OBJECT_INDEX_POS, index); - - return new UniqueId(objId); + wbb.putInt(OBJECT_INDEX_POS, index); + return new ObjectId(bytes); } /** @@ -51,26 +51,11 @@ private static UniqueId computeObjectId(UniqueId taskId, int index) { * @param putIndex What number put this object was created by in the task. * @return The computed object ID. */ - public static UniqueId computePutId(UniqueId taskId, int putIndex) { + public static ObjectId computePutId(TaskId taskId, int putIndex) { // We multiply putIndex by -1 to distinguish from returnIndex. return computeObjectId(taskId, -1 * putIndex); } - /** - * Compute the task ID of the task that created the object. - * - * @param objectId The object ID. - * @return The task ID of the task that created this object. - */ - public static UniqueId computeTaskId(UniqueId objectId) { - byte[] taskId = new byte[UniqueId.LENGTH]; - System.arraycopy(objectId.getBytes(), 0, taskId, 0, UniqueId.LENGTH); - Arrays.fill(taskId, UniqueIdUtil.OBJECT_INDEX_POS, - UniqueIdUtil.OBJECT_INDEX_POS + UniqueIdUtil.OBJECT_INDEX_LENGTH, (byte) 0); - - return new UniqueId(taskId); - } - /** * Generate the return ids of a task. * @@ -78,15 +63,15 @@ public static UniqueId computeTaskId(UniqueId objectId) { * @param numReturns The number of returnIds. * @return The Return Ids of this task. */ - public static UniqueId[] genReturnIds(UniqueId taskId, int numReturns) { - UniqueId[] ret = new UniqueId[numReturns]; + public static ObjectId[] genReturnIds(TaskId taskId, int numReturns) { + ObjectId[] ret = new ObjectId[numReturns]; for (int i = 0; i < numReturns; i++) { - ret[i] = UniqueIdUtil.computeReturnId(taskId, i + 1); + ret[i] = IdUtil.computeReturnId(taskId, i + 1); } return ret; } - public static byte[][] getIdBytes(List objectIds) { + public static byte[][] getIdBytes(List objectIds) { int size = objectIds.size(); byte[][] ids = new byte[size][]; for (int i = 0; i < size; i++) { @@ -95,6 +80,24 @@ public static byte[][] getIdBytes(List objectIds) { return ids; } + public static byte[][] getByteListFromByteBuffer(ByteBuffer byteBufferOfIds, int length) { + Preconditions.checkArgument(byteBufferOfIds != null); + + byte[] bytesOfIds = new byte[byteBufferOfIds.remaining()]; + byteBufferOfIds.get(bytesOfIds, 0, byteBufferOfIds.remaining()); + + int count = bytesOfIds.length / length; + byte[][] idBytes = new byte[count][]; + + for (int i = 0; i < count; ++i) { + byte[] id = new byte[length]; + System.arraycopy(bytesOfIds, i * length, id, 0, length); + idBytes[i] = id; + } + + return idBytes; + } + /** * Get unique IDs from concatenated ByteBuffer. * @@ -102,21 +105,31 @@ public static byte[][] getIdBytes(List objectIds) { * @return The array of unique IDs. 
*/ public static UniqueId[] getUniqueIdsFromByteBuffer(ByteBuffer byteBufferOfIds) { - Preconditions.checkArgument(byteBufferOfIds != null); + byte[][]idBytes = getByteListFromByteBuffer(byteBufferOfIds, UniqueId.LENGTH); + UniqueId[] uniqueIds = new UniqueId[idBytes.length]; - byte[] bytesOfIds = new byte[byteBufferOfIds.remaining()]; - byteBufferOfIds.get(bytesOfIds, 0, byteBufferOfIds.remaining()); + for (int i = 0; i < idBytes.length; ++i) { + uniqueIds[i] = UniqueId.fromByteBuffer(ByteBuffer.wrap(idBytes[i])); + } + + return uniqueIds; + } - int count = bytesOfIds.length / UniqueId.LENGTH; - UniqueId[] uniqueIds = new UniqueId[count]; + /** + * Get object IDs from concatenated ByteBuffer. + * + * @param byteBufferOfIds The ByteBuffer concatenated from IDs. + * @return The array of object IDs. + */ + public static ObjectId[] getObjectIdsFromByteBuffer(ByteBuffer byteBufferOfIds) { + byte[][]idBytes = getByteListFromByteBuffer(byteBufferOfIds, UniqueId.LENGTH); + ObjectId[] objectIds = new ObjectId[idBytes.length]; - for (int i = 0; i < count; ++i) { - byte[] id = new byte[UniqueId.LENGTH]; - System.arraycopy(bytesOfIds, i * UniqueId.LENGTH, id, 0, UniqueId.LENGTH); - uniqueIds[i] = UniqueId.fromByteBuffer(ByteBuffer.wrap(id)); + for (int i = 0; i < idBytes.length; ++i) { + objectIds[i] = ObjectId.fromByteBuffer(ByteBuffer.wrap(idBytes[i])); } - return uniqueIds; + return objectIds; } /** @@ -125,11 +138,15 @@ public static UniqueId[] getUniqueIdsFromByteBuffer(ByteBuffer byteBufferOfIds) * @param ids The array of IDs that will be concatenated. * @return A ByteBuffer that contains bytes of concatenated IDs. */ - public static ByteBuffer concatUniqueIds(UniqueId[] ids) { - byte[] bytesOfIds = new byte[UniqueId.LENGTH * ids.length]; + public static ByteBuffer concatIds(T[] ids) { + int length = 0; + if (ids != null && ids.length != 0) { + length = ids[0].size() * ids.length; + } + byte[] bytesOfIds = new byte[length]; for (int i = 0; i < ids.length; ++i) { System.arraycopy(ids[i].getBytes(), 0, bytesOfIds, - i * UniqueId.LENGTH, UniqueId.LENGTH); + i * ids[i].size(), ids[i].size()); } return ByteBuffer.wrap(bytesOfIds); @@ -139,8 +156,8 @@ public static ByteBuffer concatUniqueIds(UniqueId[] ids) { /** * Compute the murmur hash code of this ID. 
*/ - public static long murmurHashCode(UniqueId id) { - return murmurHash64A(id.getBytes(), UniqueId.LENGTH, 0); + public static long murmurHashCode(BaseId id) { + return murmurHash64A(id.getBytes(), id.size(), 0); } /** diff --git a/java/streaming/pom.xml b/java/streaming/pom.xml index c95976373d3c..382233fb02af 100644 --- a/java/streaming/pom.xml +++ b/java/streaming/pom.xml @@ -1,4 +1,5 @@ + @@ -26,17 +27,30 @@ ray-runtime ${project.version} - - org.slf4j - slf4j-log4j12 - - - com.google.guava - guava - - - org.testng - testng - + + com.beust + jcommander + 1.72 + + + com.google.guava + guava + 27.0.1-jre + + + org.slf4j + slf4j-api + 1.7.25 + + + org.slf4j + slf4j-log4j12 + 1.7.25 + + + org.testng + testng + 6.9.9 + diff --git a/java/streaming/pom_template.xml b/java/streaming/pom_template.xml new file mode 100644 index 000000000000..3551e7443e5c --- /dev/null +++ b/java/streaming/pom_template.xml @@ -0,0 +1,32 @@ + +{auto_gen_header} + + + org.ray + ray-superpom + 0.1-SNAPSHOT + + 4.0.0 + + streaming + ray streaming + ray streaming + + jar + + + + org.ray + ray-api + ${project.version} + + + org.ray + ray-runtime + ${project.version} + +{generated_bzl_deps} + + diff --git a/java/test.sh b/java/test.sh index 48242f39888b..ba728f14bf38 100755 --- a/java/test.sh +++ b/java/test.sh @@ -38,5 +38,5 @@ popd pushd $ROOT_DIR echo "Testing maven install." -mvn clean install -Dmaven.test.skip +mvn clean install -DskipTests popd diff --git a/java/test/pom.xml b/java/test/pom.xml index afb8da564293..10f7ea4b3313 100644 --- a/java/test/pom.xml +++ b/java/test/pom.xml @@ -1,5 +1,5 @@ - + @@ -22,34 +22,46 @@ ray-api ${project.version} - org.ray ray-runtime ${project.version} - - - org.testng - testng - - - - com.google.guava - guava - + + com.google.guava + guava + 27.0.1-jre + + + commons-io + commons-io + 2.5 + + + org.apache.commons + commons-lang3 + 3.4 + + + org.slf4j + slf4j-api + 1.7.25 + + + org.testng + testng + 6.9.9 + org.apache.maven.plugins maven-surefire-plugin - 3.0.0-M3 + 2.21.0 - false false ${basedir}/src/main/java/ - ${basedir}/src/main/resources/ ${project.build.directory}/classes/ diff --git a/java/test/pom_template.xml b/java/test/pom_template.xml index f67e735a5b80..9b8b3684f297 100644 --- a/java/test/pom_template.xml +++ b/java/test/pom_template.xml @@ -1,5 +1,5 @@ - +{auto_gen_header} @@ -22,14 +22,12 @@ ray-api ${project.version} - org.ray ray-runtime ${project.version} - - {generated_bzl_deps} +{generated_bzl_deps} diff --git a/java/test/src/main/java/org/ray/api/test/ClientExceptionTest.java b/java/test/src/main/java/org/ray/api/test/ClientExceptionTest.java index b588822712c5..227ff7e5865b 100644 --- a/java/test/src/main/java/org/ray/api/test/ClientExceptionTest.java +++ b/java/test/src/main/java/org/ray/api/test/ClientExceptionTest.java @@ -6,7 +6,7 @@ import org.ray.api.RayObject; import org.ray.api.TestUtils; import org.ray.api.exception.RayException; -import org.ray.api.id.UniqueId; +import org.ray.api.id.ObjectId; import org.ray.runtime.RayObjectImpl; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -20,7 +20,7 @@ public class ClientExceptionTest extends BaseTest { @Test public void testWaitAndCrash() { TestUtils.skipTestUnderSingleProcess(); - UniqueId randomId = UniqueId.randomId(); + ObjectId randomId = ObjectId.randomId(); RayObject notExisting = new RayObjectImpl(randomId); Thread thread = new Thread(() -> { diff --git a/java/test/src/main/java/org/ray/api/test/DynamicResourceTest.java 
b/java/test/src/main/java/org/ray/api/test/DynamicResourceTest.java new file mode 100644 index 000000000000..ffda0732287e --- /dev/null +++ b/java/test/src/main/java/org/ray/api/test/DynamicResourceTest.java @@ -0,0 +1,44 @@ +package org.ray.api.test; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import java.util.List; +import org.ray.api.Ray; +import org.ray.api.RayObject; +import org.ray.api.TestUtils; +import org.ray.api.WaitResult; +import org.ray.api.annotation.RayRemote; +import org.ray.api.options.CallOptions; +import org.ray.api.runtimecontext.NodeInfo; +import org.testng.Assert; +import org.testng.annotations.Test; + +public class DynamicResourceTest extends BaseTest { + + @RayRemote + public static String sayHi() { + return "hi"; + } + + @Test + public void testSetResource() { + TestUtils.skipTestUnderSingleProcess(); + CallOptions op1 = new CallOptions(ImmutableMap.of("A", 10.0)); + RayObject obj = Ray.call(DynamicResourceTest::sayHi, op1); + WaitResult result = Ray.wait(ImmutableList.of(obj), 1, 1000); + Assert.assertEquals(result.getReady().size(), 0); + + Ray.setResource("A", 10.0); + + // Assert node info. + List nodes = Ray.getRuntimeContext().getAllNodeInfo(); + Assert.assertEquals(nodes.size(), 1); + Assert.assertEquals(nodes.get(0).resources.get("A"), 10.0); + + // Assert ray call result. + result = Ray.wait(ImmutableList.of(obj), 1, 1000); + Assert.assertEquals(result.getReady().size(), 1); + Assert.assertEquals(Ray.get(obj.getId()), "hi"); + } + +} diff --git a/java/test/src/main/java/org/ray/api/test/ObjectStoreTest.java b/java/test/src/main/java/org/ray/api/test/ObjectStoreTest.java index eaa99a2892fd..be584ba6d1be 100644 --- a/java/test/src/main/java/org/ray/api/test/ObjectStoreTest.java +++ b/java/test/src/main/java/org/ray/api/test/ObjectStoreTest.java @@ -5,7 +5,7 @@ import java.util.stream.Collectors; import org.ray.api.Ray; import org.ray.api.RayObject; -import org.ray.api.id.UniqueId; +import org.ray.api.id.ObjectId; import org.testng.Assert; import org.testng.annotations.Test; @@ -23,7 +23,7 @@ public void testPutAndGet() { @Test public void testGetMultipleObjects() { List ints = ImmutableList.of(1, 2, 3, 4, 5); - List ids = ints.stream().map(obj -> Ray.put(obj).getId()) + List ids = ints.stream().map(obj -> Ray.put(obj).getId()) .collect(Collectors.toList()); Assert.assertEquals(ints, Ray.get(ids)); } diff --git a/java/test/src/main/java/org/ray/api/test/PlasmaFreeTest.java b/java/test/src/main/java/org/ray/api/test/PlasmaFreeTest.java index 1e344e5028b3..3c36f2201a8b 100644 --- a/java/test/src/main/java/org/ray/api/test/PlasmaFreeTest.java +++ b/java/test/src/main/java/org/ray/api/test/PlasmaFreeTest.java @@ -6,7 +6,6 @@ import org.ray.api.TestUtils; import org.ray.api.annotation.RayRemote; import org.ray.runtime.AbstractRayRuntime; -import org.ray.runtime.util.UniqueIdUtil; import org.testng.Assert; import org.testng.annotations.Test; @@ -38,7 +37,7 @@ public void testDeleteCreatingTasks() { final boolean result = TestUtils.waitForCondition( () -> !(((AbstractRayRuntime)Ray.internal()).getGcsClient()) - .rayletTaskExistsInGcs(UniqueIdUtil.computeTaskId(helloId.getId())), 50); + .rayletTaskExistsInGcs(helloId.getId().getTaskId()), 50); Assert.assertTrue(result); } diff --git a/java/test/src/main/java/org/ray/api/test/StressTest.java b/java/test/src/main/java/org/ray/api/test/StressTest.java index b5bf1356ea4f..e2efecbf222e 100644 --- a/java/test/src/main/java/org/ray/api/test/StressTest.java +++ 
b/java/test/src/main/java/org/ray/api/test/StressTest.java @@ -7,7 +7,7 @@ import org.ray.api.RayActor; import org.ray.api.RayObject; import org.ray.api.TestUtils; -import org.ray.api.id.UniqueId; +import org.ray.api.id.ObjectId; import org.testng.Assert; import org.testng.annotations.Test; @@ -23,7 +23,7 @@ public void testSubmittingTasks() { for (int numIterations : ImmutableList.of(1, 10, 100, 1000)) { int numTasks = 1000 / numIterations; for (int i = 0; i < numIterations; i++) { - List resultIds = new ArrayList<>(); + List resultIds = new ArrayList<>(); for (int j = 0; j < numTasks; j++) { resultIds.add(Ray.call(StressTest::echo, 1).getId()); } @@ -60,7 +60,7 @@ public Worker(RayActor actor) { } public int ping(int n) { - List objectIds = new ArrayList<>(); + List objectIds = new ArrayList<>(); for (int i = 0; i < n; i++) { objectIds.add(Ray.call(Actor::ping, actor).getId()); } @@ -76,7 +76,7 @@ public int ping(int n) { public void testSubmittingManyTasksToOneActor() { TestUtils.skipTestUnderSingleProcess(); RayActor actor = Ray.createActor(Actor::new); - List objectIds = new ArrayList<>(); + List objectIds = new ArrayList<>(); for (int i = 0; i < 10; i++) { RayActor worker = Ray.createActor(Worker::new, actor); objectIds.add(Ray.call(Worker::ping, worker, 100).getId()); diff --git a/java/test/src/main/java/org/ray/api/test/UniqueIdTest.java b/java/test/src/main/java/org/ray/api/test/UniqueIdTest.java index 5b3d773dbf2c..cc1bc7a53f3e 100644 --- a/java/test/src/main/java/org/ray/api/test/UniqueIdTest.java +++ b/java/test/src/main/java/org/ray/api/test/UniqueIdTest.java @@ -3,8 +3,10 @@ import java.nio.ByteBuffer; import java.util.Arrays; import javax.xml.bind.DatatypeConverter; +import org.ray.api.id.ObjectId; +import org.ray.api.id.TaskId; import org.ray.api.id.UniqueId; -import org.ray.runtime.util.UniqueIdUtil; +import org.ray.runtime.util.IdUtil; import org.testng.Assert; import org.testng.annotations.Test; @@ -42,7 +44,7 @@ public void testConstructUniqueId() { // Test `genNil()` - UniqueId id6 = UniqueId.genNil(); + UniqueId id6 = UniqueId.NIL; Assert.assertEquals("FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF".toLowerCase(), id6.toString()); Assert.assertTrue(id6.isNil()); } @@ -50,33 +52,33 @@ public void testConstructUniqueId() { @Test public void testComputeReturnId() { // Mock a taskId, and the lowest 4 bytes should be 0. 
- UniqueId taskId = UniqueId.fromHexString("00000000123456789ABCDEF123456789ABCDEF00"); + TaskId taskId = TaskId.fromHexString("123456789ABCDEF123456789ABCDEF00"); - UniqueId returnId = UniqueIdUtil.computeReturnId(taskId, 1); - Assert.assertEquals("01000000123456789abcdef123456789abcdef00", returnId.toString()); + ObjectId returnId = IdUtil.computeReturnId(taskId, 1); + Assert.assertEquals("123456789abcdef123456789abcdef0001000000", returnId.toString()); - returnId = UniqueIdUtil.computeReturnId(taskId, 0x01020304); - Assert.assertEquals("04030201123456789abcdef123456789abcdef00", returnId.toString()); + returnId = IdUtil.computeReturnId(taskId, 0x01020304); + Assert.assertEquals("123456789abcdef123456789abcdef0004030201", returnId.toString()); } @Test public void testComputeTaskId() { - UniqueId objId = UniqueId.fromHexString("34421980123456789ABCDEF123456789ABCDEF00"); - UniqueId taskId = UniqueIdUtil.computeTaskId(objId); + ObjectId objId = ObjectId.fromHexString("123456789ABCDEF123456789ABCDEF0034421980"); + TaskId taskId = objId.getTaskId(); - Assert.assertEquals("00000000123456789abcdef123456789abcdef00", taskId.toString()); + Assert.assertEquals("123456789abcdef123456789abcdef00", taskId.toString()); } @Test public void testComputePutId() { // Mock a taskId, the lowest 4 bytes should be 0. - UniqueId taskId = UniqueId.fromHexString("00000000123456789ABCDEF123456789ABCDEF00"); + TaskId taskId = TaskId.fromHexString("123456789ABCDEF123456789ABCDEF00"); - UniqueId putId = UniqueIdUtil.computePutId(taskId, 1); - Assert.assertEquals("FFFFFFFF123456789ABCDEF123456789ABCDEF00".toLowerCase(), putId.toString()); + ObjectId putId = IdUtil.computePutId(taskId, 1); + Assert.assertEquals("123456789ABCDEF123456789ABCDEF00FFFFFFFF".toLowerCase(), putId.toString()); - putId = UniqueIdUtil.computePutId(taskId, 0x01020304); - Assert.assertEquals("FCFCFDFE123456789ABCDEF123456789ABCDEF00".toLowerCase(), putId.toString()); + putId = IdUtil.computePutId(taskId, 0x01020304); + Assert.assertEquals("123456789ABCDEF123456789ABCDEF00FCFCFDFE".toLowerCase(), putId.toString()); } @Test @@ -87,8 +89,8 @@ public void testUniqueIdsAndByteBufferInterConversion() { ids[i] = UniqueId.randomId(); } - ByteBuffer temp = UniqueIdUtil.concatUniqueIds(ids); - UniqueId[] res = UniqueIdUtil.getUniqueIdsFromByteBuffer(temp); + ByteBuffer temp = IdUtil.concatIds(ids); + UniqueId[] res = IdUtil.getUniqueIdsFromByteBuffer(temp); for (int i = 0; i < len; ++i) { Assert.assertEquals(ids[i], res[i]); @@ -98,8 +100,28 @@ public void testUniqueIdsAndByteBufferInterConversion() { @Test void testMurmurHash() { UniqueId id = UniqueId.fromHexString("3131313131313131313132323232323232323232"); - long remainder = Long.remainderUnsigned(UniqueIdUtil.murmurHashCode(id), 1000000000); + long remainder = Long.remainderUnsigned(IdUtil.murmurHashCode(id), 1000000000); Assert.assertEquals(remainder, 787616861); } + @Test + void testConcateIds() { + String taskHexStr = "123456789ABCDEF123456789ABCDEF00"; + String objectHexStr = taskHexStr + "01020304"; + ObjectId objectId1 = ObjectId.fromHexString(objectHexStr); + ObjectId objectId2 = ObjectId.fromHexString(objectHexStr); + TaskId[] taskIds = new TaskId[2]; + taskIds[0] = objectId1.getTaskId(); + taskIds[1] = objectId2.getTaskId(); + ObjectId[] objectIds = new ObjectId[2]; + objectIds[0] = objectId1; + objectIds[1] = objectId2; + String taskHexCompareStr = taskHexStr + taskHexStr; + String objectHexCompareStr = objectHexStr + objectHexStr; + Assert.assertEquals(DatatypeConverter.printHexBinary( + 
IdUtil.concatIds(taskIds).array()), taskHexCompareStr); + Assert.assertEquals(DatatypeConverter.printHexBinary( + IdUtil.concatIds(objectIds).array()), objectHexCompareStr); + } + } diff --git a/java/tutorial/pom.xml b/java/tutorial/pom.xml index 48a03dc1ca8e..b0e78b40e15e 100644 --- a/java/tutorial/pom.xml +++ b/java/tutorial/pom.xml @@ -1,4 +1,5 @@ + ray-runtime ${project.version} + + com.google.guava + guava + 27.0.1-jre + diff --git a/java/tutorial/pom_template.xml b/java/tutorial/pom_template.xml index 3ced33cf3ac2..0f7b2fdf4693 100644 --- a/java/tutorial/pom_template.xml +++ b/java/tutorial/pom_template.xml @@ -1,4 +1,5 @@ +{auto_gen_header} ray-runtime ${project.version} - {generated_bzl_deps} +{generated_bzl_deps} diff --git a/kubernetes/example.py b/kubernetes/example.py index e80a6b6c9b30..5ba0272c73e5 100644 --- a/kubernetes/example.py +++ b/kubernetes/example.py @@ -14,7 +14,7 @@ # Wait for all 4 nodes to join the cluster. while True: - num_nodes = len(ray.global_state.client_table()) + num_nodes = len(ray.nodes()) if num_nodes < 4: print("{} nodes have joined so far. Waiting for more." .format(num_nodes)) diff --git a/kubernetes/submit.yaml b/kubernetes/submit.yaml index 80eeecd9751e..e6e66ae8d944 100644 --- a/kubernetes/submit.yaml +++ b/kubernetes/submit.yaml @@ -86,7 +86,7 @@ spec: spec: affinity: podAntiAffinity: - requiredDuringSchedulingIgnoreDuringExecution: + requiredDuringSchedulingIgnoredDuringExecution: - labelSelector: matchLabels: type: ray diff --git a/python/ray/__init__.py b/python/ray/__init__.py index 1e382d1b9c2c..e1b65cdcf6c7 100644 --- a/python/ray/__init__.py +++ b/python/ray/__init__.py @@ -66,6 +66,9 @@ _config = _Config() from ray.profiling import profile # noqa: E402 +from ray.state import (global_state, nodes, tasks, objects, timeline, + object_transfer_timeline, cluster_resources, + available_resources, errors) # noqa: E402 from ray.worker import ( LOCAL_MODE, PYTHON_MODE, @@ -73,12 +76,10 @@ WORKER_MODE, connect, disconnect, - error_info, get, get_gpu_ids, get_resource_ids, get_webui_url, - global_state, init, is_initialized, put, @@ -95,9 +96,18 @@ from ray.runtime_context import _get_runtime_context # noqa: E402 # Ray version string. 
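The python/ray/__init__.py change above promotes the cluster-state helpers to top-level functions (ray.nodes, ray.tasks, ray.objects, ray.cluster_resources, ray.available_resources, ray.errors, and so on), which is why call sites such as kubernetes/example.py drop ray.global_state.client_table() in favor of ray.nodes(). A brief sketch of the updated usage, assuming a locally started cluster:

import ray

ray.init()

# Replaces len(ray.global_state.client_table()) from the old API.
print("Nodes in the cluster:", len(ray.nodes()))

# Total and currently available resources, now exposed at the top level.
print(ray.cluster_resources())
print(ray.available_resources())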
-__version__ = "0.7.0.dev3" +__version__ = "0.8.0.dev0" __all__ = [ + "global_state", + "nodes", + "tasks", + "objects", + "timeline", + "object_transfer_timeline", + "cluster_resources", + "available_resources", + "errors", "LOCAL_MODE", "PYTHON_MODE", "SCRIPT_MODE", @@ -108,12 +118,10 @@ "actor", "connect", "disconnect", - "error_info", "get", "get_gpu_ids", "get_resource_ids", "get_webui_url", - "global_state", "init", "internal", "is_initialized", diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index 31937837c780..a5f106f1e911 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -32,6 +32,7 @@ from ray.includes.libraylet cimport ( from ray.includes.unique_ids cimport ( CActorCheckpointID, CObjectID, + CClientID, ) from ray.includes.task cimport CTaskSpecification from ray.includes.ray_config cimport RayConfig @@ -87,11 +88,11 @@ def compute_put_id(TaskID task_id, int64_t put_index): if put_index < 1 or put_index > kMaxTaskPuts: raise ValueError("The range of 'put_index' should be [1, %d]" % kMaxTaskPuts) - return ObjectID(ComputePutId(task_id.native(), put_index).binary()) + return ObjectID(CObjectID.for_put(task_id.native(), put_index).binary()) def compute_task_id(ObjectID object_id): - return TaskID(ComputeTaskId(object_id.native()).binary()) + return TaskID(object_id.native().task_id().binary()) cdef c_bool is_simple_value(value, int *num_elements_contained): @@ -368,6 +369,9 @@ cdef class RayletClient: check_status(self.client.get().NotifyActorResumedFromCheckpoint( actor_id.native(), checkpoint_id.native())) + def set_resource(self, basestring resource_name, double capacity, ClientID client_id): + self.client.get().SetResource(resource_name.encode("ascii"), capacity, CClientID.from_binary(client_id.binary())) + @property def language(self): return Language.from_native(self.client.get().GetLanguage()) diff --git a/python/ray/actor.py b/python/ray/actor.py index 7c24208028b4..65642d9928ee 100644 --- a/python/ray/actor.py +++ b/python/ray/actor.py @@ -17,7 +17,6 @@ import ray.ray_constants as ray_constants import ray.signature as signature import ray.worker -from ray.utils import _random_string from ray import (ObjectID, ActorID, ActorHandleID, ActorClassID, TaskID, DriverID) @@ -187,8 +186,12 @@ class ActorClass(object): task. _resources: The default resources required by the actor creation task. _actor_method_cpus: The number of CPUs required by actor method tasks. - _exported: True if the actor class has been exported and false - otherwise. + _last_driver_id_exported_for: The ID of the driver ID of the last Ray + session during which this actor class definition was exported. This + is an imperfect mechanism used to determine if we need to export + the remote function again. It is imperfect in the sense that the + actor class definition could be exported multiple times by + different workers. _actor_methods: The actor methods. _method_decorators: Optional decorators that should be applied to the method invocation function before invoking the actor methods. 
These @@ -209,7 +212,7 @@ def __init__(self, modified_class, class_id, max_reconstructions, num_cpus, self._num_cpus = num_cpus self._num_gpus = num_gpus self._resources = resources - self._exported = False + self._last_driver_id_exported_for = None self._actor_methods = inspect.getmembers( self._modified_class, ray.utils.is_function_or_method) @@ -308,7 +311,7 @@ def _remote(self, raise Exception("Actors cannot be created before ray.init() " "has been called.") - actor_id = ActorID(_random_string()) + actor_id = ActorID.from_random() # The actor cursor is a dummy object representing the most recent # actor method invocation. For each subsequent method invocation, # the current cursor should be added as a dependency, and then @@ -342,10 +345,15 @@ def _remote(self, *copy.deepcopy(args), **copy.deepcopy(kwargs)) else: # Export the actor. - if not self._exported: + if (self._last_driver_id_exported_for is None + or self._last_driver_id_exported_for != + worker.task_driver_id): + # If this actor class was exported in a previous session, we + # need to export this function again, because current GCS + # doesn't have it. + self._last_driver_id_exported_for = worker.task_driver_id worker.function_actor_manager.export_actor_class( self._modified_class, self._actor_method_names) - self._exported = True resources = ray.utils.resources_from_resource_arguments( cpus_to_use, self._num_gpus, self._resources, num_cpus, @@ -670,7 +678,7 @@ def _serialization_helper(self, ray_forking): # to release, since it could be unpickled and submit another # dependent task at any time. Therefore, we notify the backend of a # random handle ID that will never actually be used. - new_actor_handle_id = ActorHandleID(_random_string()) + new_actor_handle_id = ActorHandleID.from_random() # Notify the backend to expect this new actor handle. The backend will # not release the cursor for any new handles until the first task for # each of the new handles is submitted. @@ -780,7 +788,7 @@ def __ray_checkpoint__(self): Class.__module__ = cls.__module__ Class.__name__ = cls.__name__ - class_id = ActorClassID(_random_string()) + class_id = ActorClassID.from_random() return ActorClass(Class, class_id, max_reconstructions, num_cpus, num_gpus, resources) @@ -803,7 +811,7 @@ def exit_actor(): worker.raylet_client.disconnect() ray.disconnect() # Disconnect global state from GCS. - ray.global_state.disconnect() + ray.state.state.disconnect() sys.exit(0) assert False, "This process should have terminated." else: @@ -923,7 +931,7 @@ def get_checkpoints_for_actor(actor_id): """Get the available checkpoints for the given actor ID, return a list sorted by checkpoint timestamp in descending order. """ - checkpoint_info = ray.worker.global_state.actor_checkpoint_info(actor_id) + checkpoint_info = ray.state.state.actor_checkpoint_info(actor_id) if checkpoint_info is None: return [] checkpoints = [ diff --git a/python/ray/autoscaler/aws/development-example.yaml b/python/ray/autoscaler/aws/development-example.yaml index 0986a48ecc05..539c28643faa 100644 --- a/python/ray/autoscaler/aws/development-example.yaml +++ b/python/ray/autoscaler/aws/development-example.yaml @@ -94,6 +94,7 @@ setup_commands: - echo 'export PATH="$HOME/anaconda3/bin:$PATH"' >> ~/.bashrc # Build Ray. - git clone https://github.com/ray-project/ray || true + - ray/ci/travis/install-bazel.sh - pip install boto3==1.4.8 cython==0.29.0 - cd ray/python; pip install -e . 
--verbose diff --git a/python/ray/autoscaler/aws/example-full.yaml b/python/ray/autoscaler/aws/example-full.yaml index c8ebe9dc31c2..7399450aeedb 100644 --- a/python/ray/autoscaler/aws/example-full.yaml +++ b/python/ray/autoscaler/aws/example-full.yaml @@ -113,9 +113,9 @@ setup_commands: # has your Ray repo pre-cloned. Then, you can replace the pip installs # below with a git checkout (and possibly a recompile). - echo 'export PATH="$HOME/anaconda3/envs/tensorflow_p36/bin:$PATH"' >> ~/.bashrc - # - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.7.0.dev2-cp27-cp27mu-manylinux1_x86_64.whl - # - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.7.0.dev2-cp35-cp35m-manylinux1_x86_64.whl - - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.7.0.dev2-cp36-cp36m-manylinux1_x86_64.whl + # - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-cp27-cp27mu-manylinux1_x86_64.whl + # - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-cp35-cp35m-manylinux1_x86_64.whl + - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-cp36-cp36m-manylinux1_x86_64.whl # Consider uncommenting these if you also want to run apt-get commands during setup # - sudo pkill -9 apt-get || true # - sudo pkill -9 dpkg || true diff --git a/python/ray/autoscaler/aws/example-gpu-docker.yaml b/python/ray/autoscaler/aws/example-gpu-docker.yaml index 37c0323fc757..79fdc055b091 100644 --- a/python/ray/autoscaler/aws/example-gpu-docker.yaml +++ b/python/ray/autoscaler/aws/example-gpu-docker.yaml @@ -105,9 +105,9 @@ file_mounts: { # List of shell commands to run to set up nodes. setup_commands: - # - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.7.0.dev2-cp27-cp27mu-manylinux1_x86_64.whl - - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.7.0.dev2-cp35-cp35m-manylinux1_x86_64.whl - # - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.7.0.dev2-cp36-cp36m-manylinux1_x86_64.whl + # - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-cp27-cp27mu-manylinux1_x86_64.whl + - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-cp35-cp35m-manylinux1_x86_64.whl + # - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-cp36-cp36m-manylinux1_x86_64.whl # Custom commands that will be run on the head node after common setup. 
head_setup_commands: diff --git a/python/ray/autoscaler/commands.py b/python/ray/autoscaler/commands.py index 9a89261be7de..faaef8c6a153 100644 --- a/python/ray/autoscaler/commands.py +++ b/python/ray/autoscaler/commands.py @@ -423,6 +423,8 @@ def rsync(config_file, source, target, override_cluster_name, down): override_cluster_name: set the name of the cluster down: whether we're syncing remote -> local """ + assert bool(source) == bool(target), ( + "Must either provide both or neither source and target.") config = yaml.load(open(config_file).read()) if override_cluster_name is not None: @@ -448,7 +450,12 @@ def rsync(config_file, source, target, override_cluster_name, down): rsync = updater.rsync_down else: rsync = updater.rsync_up - rsync(source, target, check_error=False) + + if source and target: + rsync(source, target, check_error=False) + else: + updater.sync_file_mounts(rsync) + finally: provider.cleanup() diff --git a/python/ray/autoscaler/gcp/example-full.yaml b/python/ray/autoscaler/gcp/example-full.yaml index 9575691158c1..4ab2093dd865 100644 --- a/python/ray/autoscaler/gcp/example-full.yaml +++ b/python/ray/autoscaler/gcp/example-full.yaml @@ -127,9 +127,9 @@ setup_commands: && echo 'export PATH="$HOME/anaconda3/bin:$PATH"' >> ~/.profile # Install ray - # - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.7.0.dev2-cp27-cp27mu-manylinux1_x86_64.whl - # - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.7.0.dev2-cp35-cp35m-manylinux1_x86_64.whl - - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.7.0.dev2-cp36-cp36m-manylinux1_x86_64.whl + # - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-cp27-cp27mu-manylinux1_x86_64.whl + # - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-cp35-cp35m-manylinux1_x86_64.whl + - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-cp36-cp36m-manylinux1_x86_64.whl # Custom commands that will be run on the head node after common setup. diff --git a/python/ray/autoscaler/gcp/example-gpu-docker.yaml b/python/ray/autoscaler/gcp/example-gpu-docker.yaml index 43b9d867b5b5..75e0497094cb 100644 --- a/python/ray/autoscaler/gcp/example-gpu-docker.yaml +++ b/python/ray/autoscaler/gcp/example-gpu-docker.yaml @@ -140,9 +140,9 @@ setup_commands: # - echo 'export PATH="$HOME/anaconda3/envs/tensorflow_p36/bin:$PATH"' >> ~/.bashrc # Install ray - # - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.7.0.dev2-cp27-cp27mu-manylinux1_x86_64.whl - - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.7.0.dev2-cp35-cp35m-manylinux1_x86_64.whl - # - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.7.0.dev2-cp36-cp36m-manylinux1_x86_64.whl + # - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-cp27-cp27mu-manylinux1_x86_64.whl + - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-cp35-cp35m-manylinux1_x86_64.whl + # - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-cp36-cp36m-manylinux1_x86_64.whl # Custom commands that will be run on the head node after common setup. 
head_setup_commands: diff --git a/python/ray/autoscaler/updater.py b/python/ray/autoscaler/updater.py index 9fff0c767467..c86750fe399d 100644 --- a/python/ray/autoscaler/updater.py +++ b/python/ray/autoscaler/updater.py @@ -183,25 +183,9 @@ def wait_for_ssh(self, deadline): return False - def do_update(self): - self.provider.set_node_tags(self.node_id, - {TAG_RAY_NODE_STATUS: "waiting-for-ssh"}) - - deadline = time.time() + NODE_START_WAIT_S - self.set_ssh_ip_if_required() - - # Wait for SSH access - with LogTimer("NodeUpdater: " "{}: Got SSH".format(self.node_id)): - ssh_ok = self.wait_for_ssh(deadline) - assert ssh_ok, "Unable to SSH to node" - + def sync_file_mounts(self, sync_cmd): # Rsync file mounts - self.provider.set_node_tags(self.node_id, - {TAG_RAY_NODE_STATUS: "syncing-files"}) for remote_path, local_path in self.file_mounts.items(): - logger.info("NodeUpdater: " - "{}: Syncing {} to {}...".format( - self.node_id, local_path, remote_path)) assert os.path.exists(local_path), local_path if os.path.isdir(local_path): if not local_path.endswith("/"): @@ -217,7 +201,23 @@ def do_update(self): "mkdir -p {}".format(os.path.dirname(remote_path)), redirect=redirect, ) - self.rsync_up(local_path, remote_path, redirect=redirect) + sync_cmd(local_path, remote_path, redirect=redirect) + + def do_update(self): + self.provider.set_node_tags(self.node_id, + {TAG_RAY_NODE_STATUS: "waiting-for-ssh"}) + + deadline = time.time() + NODE_START_WAIT_S + self.set_ssh_ip_if_required() + + # Wait for SSH access + with LogTimer("NodeUpdater: " "{}: Got SSH".format(self.node_id)): + ssh_ok = self.wait_for_ssh(deadline) + assert ssh_ok, "Unable to SSH to node" + + self.provider.set_node_tags(self.node_id, + {TAG_RAY_NODE_STATUS: "syncing-files"}) + self.sync_file_mounts(self.rsync_up) # Run init commands self.provider.set_node_tags(self.node_id, @@ -236,6 +236,9 @@ def do_update(self): self.ssh_cmd(cmd, redirect=redirect) def rsync_up(self, source, target, redirect=None, check_error=True): + logger.info("NodeUpdater: " + "{}: Syncing {} to {}...".format(self.node_id, source, + target)) self.set_ssh_ip_if_required() self.get_caller(check_error)( [ @@ -247,6 +250,9 @@ def rsync_up(self, source, target, redirect=None, check_error=True): stderr=redirect or sys.stderr) def rsync_down(self, source, target, redirect=None, check_error=True): + logger.info("NodeUpdater: " + "{}: Syncing {} from {}...".format(self.node_id, source, + target)) self.set_ssh_ip_if_required() self.get_caller(check_error)( [ diff --git a/python/ray/experimental/__init__.py b/python/ray/experimental/__init__.py index 425ff2d932fc..cb6438d0f2d5 100644 --- a/python/ray/experimental/__init__.py +++ b/python/ray/experimental/__init__.py @@ -2,14 +2,11 @@ from __future__ import division from __future__ import print_function -from .features import ( - flush_redis_unsafe, flush_task_and_object_metadata_unsafe, - flush_finished_tasks_unsafe, flush_evicted_objects_unsafe, - _flush_finished_tasks_unsafe_shard, _flush_evicted_objects_unsafe_shard) from .gcs_flush_policy import (set_flushing_policy, GcsFlushPolicy, SimpleGcsFlushPolicy) from .named_actors import get_actor, register_actor from .api import get, wait +from .dynamic_resources import set_resource def TensorFlowVariables(*args, **kwargs): @@ -19,10 +16,7 @@ def TensorFlowVariables(*args, **kwargs): __all__ = [ - "TensorFlowVariables", "flush_redis_unsafe", - "flush_task_and_object_metadata_unsafe", "flush_finished_tasks_unsafe", - "flush_evicted_objects_unsafe", 
"_flush_finished_tasks_unsafe_shard", - "_flush_evicted_objects_unsafe_shard", "get_actor", "register_actor", - "get", "wait", "set_flushing_policy", "GcsFlushPolicy", - "SimpleGcsFlushPolicy" + "TensorFlowVariables", "get_actor", "register_actor", "get", "wait", + "set_flushing_policy", "GcsFlushPolicy", "SimpleGcsFlushPolicy", + "set_resource" ] diff --git a/python/ray/experimental/dynamic_resources.py b/python/ray/experimental/dynamic_resources.py new file mode 100644 index 000000000000..34b2b99e65a2 --- /dev/null +++ b/python/ray/experimental/dynamic_resources.py @@ -0,0 +1,35 @@ +import ray + + +def set_resource(resource_name, capacity, client_id=None): + """ Set a resource to a specified capacity. + + This creates, updates or deletes a custom resource for a target clientId. + If the resource already exists, it's capacity is updated to the new value. + If the capacity is set to 0, the resource is deleted. + If ClientID is not specified or set to None, + the resource is created on the local client where the actor is running. + + Args: + resource_name (str): Name of the resource to be created + capacity (int): Capacity of the new resource. Resource is deleted if + capacity is 0. + client_id (str): The ClientId of the node where the resource is to be + set. + + Returns: + None + + Raises: + ValueError: This exception is raised when a non-negative capacity is + specified. + """ + if client_id is not None: + client_id_obj = ray.ClientID(ray.utils.hex_to_binary(client_id)) + else: + client_id_obj = ray.ClientID.nil() + if (capacity < 0) or (capacity != int(capacity)): + raise ValueError( + "Capacity {} must be a non-negative integer.".format(capacity)) + return ray.worker.global_worker.raylet_client.set_resource( + resource_name, capacity, client_id_obj) diff --git a/python/ray/experimental/features.py b/python/ray/experimental/features.py deleted file mode 100644 index 90f893f271fb..000000000000 --- a/python/ray/experimental/features.py +++ /dev/null @@ -1,186 +0,0 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import ray -from ray.utils import binary_to_hex - -OBJECT_INFO_PREFIX = b"OI:" -OBJECT_LOCATION_PREFIX = b"OL:" -TASK_PREFIX = b"TT:" - - -def flush_redis_unsafe(redis_client=None): - """This removes some non-critical state from the primary Redis shard. - - This removes the log files as well as the event log from Redis. This can - be used to try to address out-of-memory errors caused by the accumulation - of metadata in Redis. However, it will only partially address the issue as - much of the data is in the task table (and object table), which are not - flushed. - - Args: - redis_client: optional, if not provided then ray.init() must have been - called. - """ - if redis_client is None: - ray.worker.global_worker.check_connected() - redis_client = ray.worker.global_worker.redis_client - - # Delete the log files from the primary Redis shard. - keys = redis_client.keys("LOGFILE:*") - if len(keys) > 0: - num_deleted = redis_client.delete(*keys) - else: - num_deleted = 0 - print("Deleted {} log files from Redis.".format(num_deleted)) - - # Delete the event log from the primary Redis shard. - keys = redis_client.keys("event_log:*") - if len(keys) > 0: - num_deleted = redis_client.delete(*keys) - else: - num_deleted = 0 - print("Deleted {} event logs from Redis.".format(num_deleted)) - - -def flush_task_and_object_metadata_unsafe(): - """This removes some critical state from the Redis shards. 
- - In a multitenant environment, this will flush metadata for all jobs, which - may be undesirable. - - This removes all of the object and task metadata. This can be used to try - to address out-of-memory errors caused by the accumulation of metadata in - Redis. However, after running this command, fault tolerance will most - likely not work. - """ - ray.worker.global_worker.check_connected() - - def flush_shard(redis_client): - # Flush the task table. Note that this also flushes the driver tasks - # which may be undesirable. - num_task_keys_deleted = 0 - for key in redis_client.scan_iter(match=TASK_PREFIX + b"*"): - num_task_keys_deleted += redis_client.delete(key) - print("Deleted {} task keys from Redis.".format(num_task_keys_deleted)) - - # Flush the object information. - num_object_keys_deleted = 0 - for key in redis_client.scan_iter(match=OBJECT_INFO_PREFIX + b"*"): - num_object_keys_deleted += redis_client.delete(key) - print("Deleted {} object info keys from Redis.".format( - num_object_keys_deleted)) - - # Flush the object locations. - num_object_location_keys_deleted = 0 - for key in redis_client.scan_iter(match=OBJECT_LOCATION_PREFIX + b"*"): - num_object_location_keys_deleted += redis_client.delete(key) - print("Deleted {} object location keys from Redis.".format( - num_object_location_keys_deleted)) - - # Loop over the shards and flush all of them. - for redis_client in ray.worker.global_state.redis_clients: - flush_shard(redis_client) - - -def _task_table_shard(shard_index): - redis_client = ray.global_state.redis_clients[shard_index] - task_table_keys = redis_client.keys(TASK_PREFIX + b"*") - results = {} - for key in task_table_keys: - task_id_binary = key[len(TASK_PREFIX):] - results[binary_to_hex(task_id_binary)] = ray.global_state._task_table( - ray.TaskID(task_id_binary)) - - return results - - -def _object_table_shard(shard_index): - redis_client = ray.global_state.redis_clients[shard_index] - object_table_keys = redis_client.keys(OBJECT_LOCATION_PREFIX + b"*") - results = {} - for key in object_table_keys: - object_id_binary = key[len(OBJECT_LOCATION_PREFIX):] - results[binary_to_hex(object_id_binary)] = ( - ray.global_state._object_table(ray.ObjectID(object_id_binary))) - - return results - - -def _flush_finished_tasks_unsafe_shard(shard_index): - ray.worker.global_worker.check_connected() - - redis_client = ray.global_state.redis_clients[shard_index] - tasks = _task_table_shard(shard_index) - - keys_to_delete = [] - for task_id, task_info in tasks.items(): - if task_info["State"] == ray.experimental.state.TASK_STATUS_DONE: - keys_to_delete.append(TASK_PREFIX + - ray.utils.hex_to_binary(task_id)) - - num_task_keys_deleted = 0 - if len(keys_to_delete) > 0: - num_task_keys_deleted = redis_client.execute_command( - "del", *keys_to_delete) - - print("Deleted {} finished tasks from Redis shard." 
- .format(num_task_keys_deleted)) - - -def _flush_evicted_objects_unsafe_shard(shard_index): - ray.worker.global_worker.check_connected() - - redis_client = ray.global_state.redis_clients[shard_index] - objects = _object_table_shard(shard_index) - - keys_to_delete = [] - for object_id, object_info in objects.items(): - if object_info["ManagerIDs"] == []: - keys_to_delete.append(OBJECT_LOCATION_PREFIX + - ray.utils.hex_to_binary(object_id)) - keys_to_delete.append(OBJECT_INFO_PREFIX + - ray.utils.hex_to_binary(object_id)) - - num_object_keys_deleted = 0 - if len(keys_to_delete) > 0: - num_object_keys_deleted = redis_client.execute_command( - "del", *keys_to_delete) - - print("Deleted {} keys for evicted objects from Redis." - .format(num_object_keys_deleted)) - - -def flush_finished_tasks_unsafe(): - """This removes some critical state from the Redis shards. - - In a multitenant environment, this will flush metadata for all jobs, which - may be undesirable. - - This removes all of the metadata for finished tasks. This can be used to - try to address out-of-memory errors caused by the accumulation of metadata - in Redis. However, after running this command, fault tolerance will most - likely not work. - """ - ray.worker.global_worker.check_connected() - - for shard_index in range(len(ray.global_state.redis_clients)): - _flush_finished_tasks_unsafe_shard(shard_index) - - -def flush_evicted_objects_unsafe(): - """This removes some critical state from the Redis shards. - - In a multitenant environment, this will flush metadata for all jobs, which - may be undesirable. - - This removes all of the metadata for objects that have been evicted. This - can be used to try to address out-of-memory errors caused by the - accumulation of metadata in Redis. However, after running this command, - fault tolerance will most likely not work. - """ - ray.worker.global_worker.check_connected() - - for shard_index in range(len(ray.global_state.redis_clients)): - _flush_evicted_objects_unsafe_shard(shard_index) diff --git a/python/ray/experimental/tf_utils.py b/python/ray/experimental/tf_utils.py index d2f1b259961c..900cc948b066 100644 --- a/python/ray/experimental/tf_utils.py +++ b/python/ray/experimental/tf_utils.py @@ -5,7 +5,9 @@ from collections import deque, OrderedDict import numpy as np -import tensorflow as tf +from ray.rllib.utils import try_import_tf + +tf = try_import_tf() def unflatten(vector, shapes): diff --git a/python/ray/function_manager.py b/python/ray/function_manager.py index e4a172fc1e71..4914c9f87050 100644 --- a/python/ray/function_manager.py +++ b/python/ray/function_manager.py @@ -342,7 +342,7 @@ def export(self, remote_function): # and export it later. self._functions_to_export.append(remote_function) return - if self._worker.mode != ray.worker.SCRIPT_MODE: + if self._worker.mode == ray.worker.LOCAL_MODE: # Don't need to export if the worker is not a driver. 
return self._do_export(remote_function) diff --git a/python/ray/includes/common.pxd b/python/ray/includes/common.pxd index 3b6463fc9ea6..bdb4316fcc4e 100644 --- a/python/ray/includes/common.pxd +++ b/python/ray/includes/common.pxd @@ -81,15 +81,9 @@ cdef extern from "ray/status.h" namespace "ray::StatusCode" nogil: cdef extern from "ray/id.h" namespace "ray" nogil: - const CTaskID FinishTaskId(const CTaskID &task_id) - const CObjectID ComputeReturnId(const CTaskID &task_id, - int64_t return_index) - const CObjectID ComputePutId(const CTaskID &task_id, int64_t put_index) - const CTaskID ComputeTaskId(const CObjectID &object_id) const CTaskID GenerateTaskId(const CDriverID &driver_id, const CTaskID &parent_task_id, int parent_task_counter) - int64_t ComputeObjectIndex(const CObjectID &object_id) cdef extern from "ray/gcs/format/gcs_generated.h" nogil: diff --git a/python/ray/includes/libraylet.pxd b/python/ray/includes/libraylet.pxd index be74b06e5729..1b4c5e3cd037 100644 --- a/python/ray/includes/libraylet.pxd +++ b/python/ray/includes/libraylet.pxd @@ -72,6 +72,7 @@ cdef extern from "ray/raylet/raylet_client.h" nogil: CActorCheckpointID &checkpoint_id) CRayStatus NotifyActorResumedFromCheckpoint( const CActorID &actor_id, const CActorCheckpointID &checkpoint_id) + CRayStatus SetResource(const c_string &resource_name, const double capacity, const CClientID &client_Id) CLanguage GetLanguage() const CClientID GetClientID() const CDriverID GetDriverID() const diff --git a/python/ray/includes/unique_ids.pxd b/python/ray/includes/unique_ids.pxd index a607b2a86419..fbe793cc023b 100644 --- a/python/ray/includes/unique_ids.pxd +++ b/python/ray/includes/unique_ids.pxd @@ -1,12 +1,35 @@ from libcpp cimport bool as c_bool from libcpp.string cimport string as c_string -from libc.stdint cimport uint8_t +from libc.stdint cimport uint8_t, int64_t cdef extern from "ray/id.h" namespace "ray" nogil: - cdef cppclass CUniqueID "ray::UniqueID": + cdef cppclass CBaseID[T]: + @staticmethod + T from_random() + + @staticmethod + T from_binary(const c_string &binary) + + @staticmethod + const T nil() + + @staticmethod + size_t size() + + size_t hash() const + c_bool is_nil() const + c_bool operator==(const CBaseID &rhs) const + c_bool operator!=(const CBaseID &rhs) const + const uint8_t *data() const; + + c_string binary() const; + c_string hex() const; + + cdef cppclass CUniqueID "ray::UniqueID"(CBaseID): CUniqueID() - CUniqueID(const c_string &binary) - CUniqueID(const CUniqueID &from_id) + + @staticmethod + size_t size() @staticmethod CUniqueID from_random() @@ -17,15 +40,8 @@ cdef extern from "ray/id.h" namespace "ray" nogil: @staticmethod const CUniqueID nil() - size_t hash() const - c_bool is_nil() const - c_bool operator==(const CUniqueID& rhs) const - c_bool operator!=(const CUniqueID& rhs) const - const uint8_t *data() const - uint8_t *mutable_data() - size_t size() const - c_string binary() const - c_string hex() const + @staticmethod + size_t size() cdef cppclass CActorCheckpointID "ray::ActorCheckpointID"(CUniqueID): @@ -67,16 +83,40 @@ cdef extern from "ray/id.h" namespace "ray" nogil: @staticmethod CDriverID from_binary(const c_string &binary) - cdef cppclass CTaskID "ray::TaskID"(CUniqueID): + cdef cppclass CTaskID "ray::TaskID"(CBaseID[CTaskID]): @staticmethod CTaskID from_binary(const c_string &binary) - cdef cppclass CObjectID" ray::ObjectID"(CUniqueID): + @staticmethod + const CTaskID nil() + + @staticmethod + size_t size() + + cdef cppclass CObjectID" ray::ObjectID"(CBaseID[CObjectID]): 
@staticmethod CObjectID from_binary(const c_string &binary) + @staticmethod + const CObjectID nil() + + @staticmethod + CObjectID for_put(const CTaskID &task_id, int64_t index); + + @staticmethod + CObjectID for_task_return(const CTaskID &task_id, int64_t index); + + @staticmethod + size_t size() + + c_bool is_put() + + int64_t object_index() const + + CTaskID task_id() const + cdef cppclass CWorkerID "ray::WorkerID"(CUniqueID): @staticmethod diff --git a/python/ray/includes/unique_ids.pxi b/python/ray/includes/unique_ids.pxi index c96668f2bf07..b9773d56fb20 100644 --- a/python/ray/includes/unique_ids.pxi +++ b/python/ray/includes/unique_ids.pxi @@ -6,10 +6,8 @@ See https://github.com/ray-project/ray/issues/3721. # WARNING: Any additional ID types defined in this file must be added to the # _ID_TYPES list at the bottom of this file. -from ray.includes.common cimport ( - ComputePutId, - ComputeTaskId, -) +import os + from ray.includes.unique_ids cimport ( CActorCheckpointID, CActorClassID, @@ -28,12 +26,12 @@ from ray.includes.unique_ids cimport ( from ray.utils import decode -def check_id(b): +def check_id(b, size=kUniqueIDSize): if not isinstance(b, bytes): raise TypeError("Unsupported type: " + str(type(b))) - if len(b) != kUniqueIDSize: + if len(b) != size: raise ValueError("ID string needs to have length " + - str(kUniqueIDSize)) + str(size)) cdef extern from "ray/constants.h" nogil: @@ -41,28 +39,27 @@ cdef extern from "ray/constants.h" nogil: cdef int64_t kMaxTaskPuts -cdef class UniqueID: - cdef CUniqueID data +cdef class BaseID: - def __init__(self, id): - check_id(id) - self.data = CUniqueID.from_binary(id) + # To avoid the error of "Python int too large to convert to C ssize_t", + # here `cdef size_t` is required. + cdef size_t hash(self): + pass - @classmethod - def from_binary(cls, id_bytes): - if not isinstance(id_bytes, bytes): - raise TypeError("Expect bytes, got " + str(type(id_bytes))) - return cls(id_bytes) + def binary(self): + pass - @classmethod - def nil(cls): - return cls(CUniqueID.nil().binary()) + def size(self): + pass - def __hash__(self): - return self.data.hash() + def hex(self): + pass def is_nil(self): - return self.data.is_nil() + pass + + def __hash__(self): + return self.hash() def __eq__(self, other): return type(self) == type(other) and self.binary() == other.binary() @@ -70,18 +67,9 @@ cdef class UniqueID: def __ne__(self, other): return self.binary() != other.binary() - def size(self): - return self.data.size() - - def binary(self): - return self.data.binary() - def __bytes__(self): return self.binary() - def hex(self): - return decode(self.data.hex()) - def __hex__(self): return self.hex() @@ -98,11 +86,52 @@ cdef class UniqueID: # NOTE: The hash function used here must match the one in # GetRedisContext in src/ray/gcs/tables.h. Changes to the # hash function should only be made through std::hash in - # src/common/common.h + # src/common/common.h. + # Do not use __hash__ that returns signed uint64_t, which + # is different from std::hash in c++ code. 
+ return self.hash() + + +cdef class UniqueID(BaseID): + cdef CUniqueID data + + def __init__(self, id): + check_id(id) + self.data = CUniqueID.from_binary(id) + + @classmethod + def from_binary(cls, id_bytes): + if not isinstance(id_bytes, bytes): + raise TypeError("Expect bytes, got " + str(type(id_bytes))) + return cls(id_bytes) + + @classmethod + def nil(cls): + return cls(CUniqueID.nil().binary()) + + + @classmethod + def from_random(cls): + return cls(os.urandom(CUniqueID.size())) + + def size(self): + return CUniqueID.size() + + def binary(self): + return self.data.binary() + + def hex(self): + return decode(self.data.hex()) + + def is_nil(self): + return self.data.is_nil() + + cdef size_t hash(self): return self.data.hash() -cdef class ObjectID(UniqueID): +cdef class ObjectID(BaseID): + cdef CObjectID data def __init__(self, id): check_id(id) @@ -111,16 +140,67 @@ cdef class ObjectID(UniqueID): cdef CObjectID native(self): return self.data + def size(self): + return CObjectID.size() -cdef class TaskID(UniqueID): + def binary(self): + return self.data.binary() + + def hex(self): + return decode(self.data.hex()) + + def is_nil(self): + return self.data.is_nil() + + cdef size_t hash(self): + return self.data.hash() + + @classmethod + def nil(cls): + return cls(CObjectID.nil().binary()) + + @classmethod + def from_random(cls): + return cls(os.urandom(CObjectID.size())) + + +cdef class TaskID(BaseID): + cdef CTaskID data def __init__(self, id): - check_id(id) + check_id(id, CTaskID.size()) self.data = CTaskID.from_binary(id) cdef CTaskID native(self): return self.data + def size(self): + return CTaskID.size() + + def binary(self): + return self.data.binary() + + def hex(self): + return decode(self.data.hex()) + + def is_nil(self): + return self.data.is_nil() + + cdef size_t hash(self): + return self.data.hash() + + @classmethod + def nil(cls): + return cls(CTaskID.nil().binary()) + + @classmethod + def size(cla): + return CTaskID.size() + + @classmethod + def from_random(cls): + return cls(os.urandom(CTaskID.size())) + cdef class ClientID(UniqueID): diff --git a/python/ray/monitor.py b/python/ray/monitor.py index ded86611e88c..09a154d7b548 100644 --- a/python/ray/monitor.py +++ b/python/ray/monitor.py @@ -16,8 +16,8 @@ import ray.gcs_utils import ray.utils import ray.ray_constants as ray_constants -from ray.utils import (binary_to_hex, binary_to_object_id, hex_to_binary, - setup_logger) +from ray.utils import (binary_to_hex, binary_to_object_id, binary_to_task_id, + hex_to_binary, setup_logger) logger = logging.getLogger(__name__) @@ -37,8 +37,7 @@ class Monitor(object): def __init__(self, redis_address, autoscaling_config, redis_password=None): # Initialize the Redis clients. - self.state = ray.experimental.state.GlobalState() - self.state._initialize_global_state( + ray.state.state._initialize_global_state( args.redis_address, redis_password=redis_password) self.redis = ray.services.create_redis_client( redis_address, password=redis_password) @@ -149,7 +148,7 @@ def _xray_clean_up_entries_for_driver(self, driver_id): xray_object_table_prefix = ( ray.gcs_utils.TablePrefix_OBJECT_string.encode("ascii")) - task_table_objects = self.state.task_table() + task_table_objects = ray.tasks() driver_id_hex = binary_to_hex(driver_id) driver_task_id_bins = set() for task_id_hex, task_info in task_table_objects.items(): @@ -161,7 +160,7 @@ def _xray_clean_up_entries_for_driver(self, driver_id): driver_task_id_bins.add(hex_to_binary(task_id_hex)) # Get objects associated with the driver. 
-        object_table_objects = self.state.object_table()
+        object_table_objects = ray.objects()
         driver_object_id_bins = set()
         for object_id, _ in object_table_objects.items():
             task_id_bin = ray._raylet.compute_task_id(object_id).binary()
@@ -169,11 +168,15 @@ def _xray_clean_up_entries_for_driver(self, driver_id):
                 driver_object_id_bins.add(object_id.binary())

         def to_shard_index(id_bin):
-            return binary_to_object_id(id_bin).redis_shard_hash() % len(
-                self.state.redis_clients)
+            if len(id_bin) == ray.TaskID.size():
+                return binary_to_task_id(id_bin).redis_shard_hash() % len(
+                    ray.state.state.redis_clients)
+            else:
+                return binary_to_object_id(id_bin).redis_shard_hash() % len(
+                    ray.state.state.redis_clients)

         # Form the redis keys to delete.
-        sharded_keys = [[] for _ in range(len(self.state.redis_clients))]
+        sharded_keys = [[] for _ in range(len(ray.state.state.redis_clients))]
         for task_id_bin in driver_task_id_bins:
             sharded_keys[to_shard_index(task_id_bin)].append(
                 xray_task_table_prefix + task_id_bin)
@@ -186,7 +189,7 @@ def to_shard_index(id_bin):
             keys = sharded_keys[shard_index]
             if len(keys) == 0:
                 continue
-            redis = self.state.redis_clients[shard_index]
+            redis = ray.state.state.redis_clients[shard_index]
             num_deleted = redis.delete(*keys)
             logger.info("Monitor: "
                         "Removed {} dead redis entries of the "
@@ -252,7 +255,7 @@ def process_messages(self, max_messages=10000):
             message_handler(channel, data)

     def update_raylet_map(self):
-        all_raylet_nodes = self.state.client_table()
+        all_raylet_nodes = ray.nodes()
         self.raylet_id_to_ip_map = {}
         for raylet_info in all_raylet_nodes:
             client_id = (raylet_info.get("DBClientID")
diff --git a/python/ray/node.py b/python/ray/node.py
index 733f21c9d728..85510147a35f 100644
--- a/python/ray/node.py
+++ b/python/ray/node.py
@@ -435,6 +435,7 @@ def start_raylet(self, use_valgrind=False, use_profiler=False):
             self._plasma_store_socket_name,
             self._ray_params.worker_path,
             self._temp_dir,
+            self._session_dir,
             self._ray_params.num_cpus,
             self._ray_params.num_gpus,
             self._ray_params.resources,
diff --git a/python/ray/remote_function.py b/python/ray/remote_function.py
index 3bc3fc2bd92e..44d2777a2900 100644
--- a/python/ray/remote_function.py
+++ b/python/ray/remote_function.py
@@ -43,6 +43,12 @@ class RemoteFunction(object):
             return the resulting ObjectIDs. For an example, see
             "test_decorated_function" in "python/ray/tests/test_basic.py".
         _function_signature: The function signature.
+        _last_driver_id_exported_for: The ID of the driver of the last Ray
+            session during which this remote function definition was exported.
+            This is an imperfect mechanism used to determine if we need to
+            export the remote function again. It is imperfect in the sense
+            that the remote function definition could be exported multiple
+            times by different workers.
     """

     def __init__(self, function, num_cpus, num_gpus, resources,
@@ -66,11 +72,7 @@ def __init__(self, function, num_cpus, num_gpus, resources,
         self._function_signature = ray.signature.extract_signature(
             self._function)

-        # Export the function.
-        worker = ray.worker.get_global_worker()
-        # In which session this function was exported last time.
-        self._last_export_session = worker._session_index
-        worker.function_actor_manager.export(self)
+        self._last_driver_id_exported_for = None

     def __call__(self, *args, **kwargs):
         raise Exception("Remote functions cannot be called directly.
Instead " @@ -109,10 +111,11 @@ def _remote(self, worker = ray.worker.get_global_worker() worker.check_connected() - if self._last_export_session < worker._session_index: + if (self._last_driver_id_exported_for is None + or self._last_driver_id_exported_for != worker.task_driver_id): # If this function was exported in a previous session, we need to # export this function again, because current GCS doesn't have it. - self._last_export_session = worker._session_index + self._last_driver_id_exported_for = worker.task_driver_id worker.function_actor_manager.export(self) kwargs = {} if kwargs is None else kwargs diff --git a/python/ray/rllib/__init__.py b/python/ray/rllib/__init__.py index 613199cf795f..92844e485ff3 100644 --- a/python/ray/rllib/__init__.py +++ b/python/ray/rllib/__init__.py @@ -3,6 +3,7 @@ from __future__ import print_function import logging +import sys # Note: do not introduce unnecessary library dependencies here, e.g. gym. # This file is imported from the tune module in order to register RLlib agents. @@ -10,12 +11,14 @@ from ray.rllib.evaluation.policy_graph import PolicyGraph from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph +from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator from ray.rllib.env.base_env import BaseEnv from ray.rllib.env.multi_agent_env import MultiAgentEnv from ray.rllib.env.vector_env import VectorEnv from ray.rllib.env.external_env import ExternalEnv -from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator -from ray.rllib.evaluation.sample_batch import SampleBatch +from ray.rllib.policy.policy import Policy +from ray.rllib.policy.tf_policy import TFPolicy +from ray.rllib.policy.sample_batch import SampleBatch def _setup_logger(): @@ -28,6 +31,11 @@ def _setup_logger(): logger.addHandler(handler) logger.propagate = False + if sys.version_info[0] < 3: + logger.warn( + "RLlib Python 2 support is deprecated, and will be removed " + "in a future release.") + def _register_all(): @@ -43,7 +51,9 @@ def _register_all(): _register_all() __all__ = [ + "Policy", "PolicyGraph", + "TFPolicy", "TFPolicyGraph", "PolicyEvaluator", "SampleBatch", diff --git a/python/ray/rllib/agents/a3c/__init__.py b/python/ray/rllib/agents/a3c/__init__.py index 9c8205389ea2..4a8480eab695 100644 --- a/python/ray/rllib/agents/a3c/__init__.py +++ b/python/ray/rllib/agents/a3c/__init__.py @@ -1,9 +1,9 @@ from ray.rllib.agents.a3c.a3c import A3CTrainer, DEFAULT_CONFIG from ray.rllib.agents.a3c.a2c import A2CTrainer -from ray.rllib.utils import renamed_class +from ray.rllib.utils import renamed_agent -A2CAgent = renamed_class(A2CTrainer) -A3CAgent = renamed_class(A3CTrainer) +A2CAgent = renamed_agent(A2CTrainer) +A3CAgent = renamed_agent(A3CTrainer) __all__ = [ "A2CAgent", "A3CAgent", "A2CTrainer", "A3CTrainer", "DEFAULT_CONFIG" diff --git a/python/ray/rllib/agents/a3c/a3c.py b/python/ray/rllib/agents/a3c/a3c.py index 836d9f074999..56d7a09daa0f 100644 --- a/python/ray/rllib/agents/a3c/a3c.py +++ b/python/ray/rllib/agents/a3c/a3c.py @@ -4,7 +4,7 @@ import time -from ray.rllib.agents.a3c.a3c_tf_policy_graph import A3CPolicyGraph +from ray.rllib.agents.a3c.a3c_tf_policy import A3CTFPolicy from ray.rllib.agents.trainer import Trainer, with_common_config from ray.rllib.optimizers import AsyncGradientsOptimizer from ray.rllib.utils.annotations import override @@ -43,16 +43,16 @@ class A3CTrainer(Trainer): _name = "A3C" _default_config = DEFAULT_CONFIG - _policy_graph = A3CPolicyGraph + _policy = A3CTFPolicy @override(Trainer) def _init(self, config, 
env_creator): if config["use_pytorch"]: - from ray.rllib.agents.a3c.a3c_torch_policy_graph import \ - A3CTorchPolicyGraph - policy_cls = A3CTorchPolicyGraph + from ray.rllib.agents.a3c.a3c_torch_policy import \ + A3CTorchPolicy + policy_cls = A3CTorchPolicy else: - policy_cls = self._policy_graph + policy_cls = self._policy if config["entropy_coeff"] < 0: raise DeprecationWarning("entropy_coeff must be >= 0") diff --git a/python/ray/rllib/agents/a3c/a3c_tf_policy_graph.py b/python/ray/rllib/agents/a3c/a3c_tf_policy.py similarity index 91% rename from python/ray/rllib/agents/a3c/a3c_tf_policy_graph.py rename to python/ray/rllib/agents/a3c/a3c_tf_policy.py index d4e140543e31..eb5becceaa71 100644 --- a/python/ray/rllib/agents/a3c/a3c_tf_policy_graph.py +++ b/python/ray/rllib/agents/a3c/a3c_tf_policy.py @@ -1,24 +1,26 @@ -"""Note: Keep in sync with changes to VTracePolicyGraph.""" +"""Note: Keep in sync with changes to VTraceTFPolicy.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -import tensorflow as tf import gym import ray from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY -from ray.rllib.evaluation.sample_batch import SampleBatch +from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.utils.error import UnsupportedSpaceException from ray.rllib.utils.explained_variance import explained_variance -from ray.rllib.evaluation.policy_graph import PolicyGraph +from ray.rllib.policy.policy import Policy from ray.rllib.evaluation.postprocessing import compute_advantages, \ Postprocessing -from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph, \ +from ray.rllib.policy.tf_policy import TFPolicy, \ LearningRateSchedule from ray.rllib.models.catalog import ModelCatalog from ray.rllib.utils.annotations import override +from ray.rllib.utils import try_import_tf + +tf = try_import_tf() class A3CLoss(object): @@ -45,13 +47,13 @@ def __init__(self, class A3CPostprocessing(object): """Adds the VF preds and advantages fields to the trajectory.""" - @override(TFPolicyGraph) + @override(TFPolicy) def extra_compute_action_fetches(self): return dict( - TFPolicyGraph.extra_compute_action_fetches(self), + TFPolicy.extra_compute_action_fetches(self), **{SampleBatch.VF_PREDS: self.vf}) - @override(PolicyGraph) + @override(Policy) def postprocess_trajectory(self, sample_batch, other_agent_batches=None, @@ -71,7 +73,7 @@ def postprocess_trajectory(self, self.config["lambda"]) -class A3CPolicyGraph(LearningRateSchedule, A3CPostprocessing, TFPolicyGraph): +class A3CTFPolicy(LearningRateSchedule, A3CPostprocessing, TFPolicy): def __init__(self, observation_space, action_space, config): config = dict(ray.rllib.agents.a3c.a3c.DEFAULT_CONFIG, **config) self.config = config @@ -112,7 +114,7 @@ def __init__(self, observation_space, action_space, config): self.vf, self.config["vf_loss_coeff"], self.config["entropy_coeff"]) - # Initialize TFPolicyGraph + # Initialize TFPolicy loss_in = [ (SampleBatch.CUR_OBS, self.observations), (SampleBatch.ACTIONS, actions), @@ -123,7 +125,7 @@ def __init__(self, observation_space, action_space, config): ] LearningRateSchedule.__init__(self, self.config["lr"], self.config["lr_schedule"]) - TFPolicyGraph.__init__( + TFPolicy.__init__( self, observation_space, action_space, @@ -155,18 +157,18 @@ def __init__(self, observation_space, action_space, config): self.sess.run(tf.global_variables_initializer()) - @override(PolicyGraph) + @override(Policy) def get_initial_state(self): return 
self.model.state_init - @override(TFPolicyGraph) + @override(TFPolicy) def gradients(self, optimizer, loss): grads = tf.gradients(loss, self.var_list) self.grads, _ = tf.clip_by_global_norm(grads, self.config["grad_clip"]) clipped_grads = list(zip(self.grads, self.var_list)) return clipped_grads - @override(TFPolicyGraph) + @override(TFPolicy) def extra_compute_grad_fetches(self): return self.stats_fetches diff --git a/python/ray/rllib/agents/a3c/a3c_torch_policy.py b/python/ray/rllib/agents/a3c/a3c_torch_policy.py new file mode 100644 index 000000000000..6ccf6c48d35f --- /dev/null +++ b/python/ray/rllib/agents/a3c/a3c_torch_policy.py @@ -0,0 +1,90 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import torch +import torch.nn.functional as F +from torch import nn + +import ray +from ray.rllib.evaluation.postprocessing import compute_advantages, \ + Postprocessing +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.policy.torch_policy_template import build_torch_policy + + +def actor_critic_loss(policy, batch_tensors): + logits, _, values, _ = policy.model({ + SampleBatch.CUR_OBS: batch_tensors[SampleBatch.CUR_OBS] + }, []) + dist = policy.dist_class(logits) + log_probs = dist.logp(batch_tensors[SampleBatch.ACTIONS]) + policy.entropy = dist.entropy().mean() + policy.pi_err = -batch_tensors[Postprocessing.ADVANTAGES].dot( + log_probs.reshape(-1)) + policy.value_err = F.mse_loss( + values.reshape(-1), batch_tensors[Postprocessing.VALUE_TARGETS]) + overall_err = sum([ + policy.pi_err, + policy.config["vf_loss_coeff"] * policy.value_err, + -policy.config["entropy_coeff"] * policy.entropy, + ]) + return overall_err + + +def loss_and_entropy_stats(policy, batch_tensors): + return { + "policy_entropy": policy.entropy.item(), + "policy_loss": policy.pi_err.item(), + "vf_loss": policy.value_err.item(), + } + + +def add_advantages(policy, + sample_batch, + other_agent_batches=None, + episode=None): + completed = sample_batch[SampleBatch.DONES][-1] + if completed: + last_r = 0.0 + else: + last_r = policy._value(sample_batch[SampleBatch.NEXT_OBS][-1]) + return compute_advantages(sample_batch, last_r, policy.config["gamma"], + policy.config["lambda"]) + + +def model_value_predictions(policy, model_out): + return {SampleBatch.VF_PREDS: model_out[2].cpu().numpy()} + + +def apply_grad_clipping(policy): + info = {} + if policy.config["grad_clip"]: + total_norm = nn.utils.clip_grad_norm_(policy.model.parameters(), + policy.config["grad_clip"]) + info["grad_gnorm"] = total_norm + return info + + +def torch_optimizer(policy, config): + return torch.optim.Adam(policy.model.parameters(), lr=config["lr"]) + + +class ValueNetworkMixin(object): + def _value(self, obs): + with self.lock: + obs = torch.from_numpy(obs).float().unsqueeze(0).to(self.device) + _, _, vf, _ = self.model({"obs": obs}, []) + return vf.detach().cpu().numpy().squeeze() + + +A3CTorchPolicy = build_torch_policy( + name="A3CTorchPolicy", + get_default_config=lambda: ray.rllib.agents.a3c.a3c.DEFAULT_CONFIG, + loss_fn=actor_critic_loss, + stats_fn=loss_and_entropy_stats, + postprocess_fn=add_advantages, + extra_action_out_fn=model_value_predictions, + extra_grad_process_fn=apply_grad_clipping, + optimizer_fn=torch_optimizer, + mixins=[ValueNetworkMixin]) diff --git a/python/ray/rllib/agents/a3c/a3c_torch_policy_graph.py b/python/ray/rllib/agents/a3c/a3c_torch_policy_graph.py deleted file mode 100644 index d35aabe0d667..000000000000 --- 
a/python/ray/rllib/agents/a3c/a3c_torch_policy_graph.py +++ /dev/null @@ -1,115 +0,0 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import torch -import torch.nn.functional as F -from torch import nn - -import ray -from ray.rllib.models.catalog import ModelCatalog -from ray.rllib.evaluation.postprocessing import compute_advantages, \ - Postprocessing -from ray.rllib.evaluation.policy_graph import PolicyGraph -from ray.rllib.evaluation.sample_batch import SampleBatch -from ray.rllib.evaluation.torch_policy_graph import TorchPolicyGraph -from ray.rllib.utils.annotations import override - - -class A3CLoss(nn.Module): - def __init__(self, dist_class, vf_loss_coeff=0.5, entropy_coeff=0.01): - nn.Module.__init__(self) - self.dist_class = dist_class - self.vf_loss_coeff = vf_loss_coeff - self.entropy_coeff = entropy_coeff - - def forward(self, policy_model, observations, actions, advantages, - value_targets): - logits, _, values, _ = policy_model({ - SampleBatch.CUR_OBS: observations - }, []) - dist = self.dist_class(logits) - log_probs = dist.logp(actions) - self.entropy = dist.entropy().mean() - self.pi_err = -advantages.dot(log_probs.reshape(-1)) - self.value_err = F.mse_loss(values.reshape(-1), value_targets) - overall_err = sum([ - self.pi_err, - self.vf_loss_coeff * self.value_err, - -self.entropy_coeff * self.entropy, - ]) - - return overall_err - - -class A3CPostprocessing(object): - """Adds the VF preds and advantages fields to the trajectory.""" - - @override(TorchPolicyGraph) - def extra_action_out(self, model_out): - return {SampleBatch.VF_PREDS: model_out[2].cpu().numpy()} - - @override(PolicyGraph) - def postprocess_trajectory(self, - sample_batch, - other_agent_batches=None, - episode=None): - completed = sample_batch[SampleBatch.DONES][-1] - if completed: - last_r = 0.0 - else: - last_r = self._value(sample_batch[SampleBatch.NEXT_OBS][-1]) - return compute_advantages(sample_batch, last_r, self.config["gamma"], - self.config["lambda"]) - - -class A3CTorchPolicyGraph(A3CPostprocessing, TorchPolicyGraph): - """A simple, non-recurrent PyTorch policy example.""" - - def __init__(self, obs_space, action_space, config): - config = dict(ray.rllib.agents.a3c.a3c.DEFAULT_CONFIG, **config) - self.config = config - dist_class, self.logit_dim = ModelCatalog.get_action_dist( - action_space, self.config["model"], torch=True) - model = ModelCatalog.get_torch_model(obs_space, self.logit_dim, - self.config["model"]) - loss = A3CLoss(dist_class, self.config["vf_loss_coeff"], - self.config["entropy_coeff"]) - TorchPolicyGraph.__init__( - self, - obs_space, - action_space, - model, - loss, - loss_inputs=[ - SampleBatch.CUR_OBS, SampleBatch.ACTIONS, - Postprocessing.ADVANTAGES, Postprocessing.VALUE_TARGETS - ], - action_distribution_cls=dist_class) - - @override(TorchPolicyGraph) - def optimizer(self): - return torch.optim.Adam(self._model.parameters(), lr=self.config["lr"]) - - @override(TorchPolicyGraph) - def extra_grad_process(self): - info = {} - if self.config["grad_clip"]: - total_norm = nn.utils.clip_grad_norm_(self._model.parameters(), - self.config["grad_clip"]) - info["grad_gnorm"] = total_norm - return info - - @override(TorchPolicyGraph) - def extra_grad_info(self): - return { - "policy_entropy": self._loss.entropy.item(), - "policy_loss": self._loss.pi_err.item(), - "vf_loss": self._loss.value_err.item() - } - - def _value(self, obs): - with self.lock: - obs = torch.from_numpy(obs).float().unsqueeze(0).to(self.device) 
- _, _, vf, _ = self._model({"obs": obs}, []) - return vf.detach().cpu().numpy().squeeze() diff --git a/python/ray/rllib/agents/agent.py b/python/ray/rllib/agents/agent.py index 5b0ecf268fe7..17da952ddedf 100644 --- a/python/ray/rllib/agents/agent.py +++ b/python/ray/rllib/agents/agent.py @@ -3,6 +3,6 @@ from __future__ import print_function from ray.rllib.agents.trainer import Trainer -from ray.rllib.utils import renamed_class +from ray.rllib.utils import renamed_agent -Agent = renamed_class(Trainer) +Agent = renamed_agent(Trainer) diff --git a/python/ray/rllib/agents/ars/__init__.py b/python/ray/rllib/agents/ars/__init__.py index a1120ff8ce31..0681efe7ab37 100644 --- a/python/ray/rllib/agents/ars/__init__.py +++ b/python/ray/rllib/agents/ars/__init__.py @@ -1,6 +1,6 @@ from ray.rllib.agents.ars.ars import (ARSTrainer, DEFAULT_CONFIG) -from ray.rllib.utils import renamed_class +from ray.rllib.utils import renamed_agent -ARSAgent = renamed_class(ARSTrainer) +ARSAgent = renamed_agent(ARSTrainer) __all__ = ["ARSAgent", "ARSTrainer", "DEFAULT_CONFIG"] diff --git a/python/ray/rllib/agents/ars/ars.py b/python/ray/rllib/agents/ars/ars.py index 65738a620b30..4330f0d90db0 100644 --- a/python/ray/rllib/agents/ars/ars.py +++ b/python/ray/rllib/agents/ars/ars.py @@ -17,7 +17,7 @@ from ray.rllib.agents.ars import optimizers from ray.rllib.agents.ars import policies from ray.rllib.agents.ars import utils -from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID +from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID from ray.rllib.utils.annotations import override from ray.rllib.utils.memory import ray_get_and_free from ray.rllib.utils import FilterManager diff --git a/python/ray/rllib/agents/ars/policies.py b/python/ray/rllib/agents/ars/policies.py index fe82be5b65dd..7fdb54b99cd8 100644 --- a/python/ray/rllib/agents/ars/policies.py +++ b/python/ray/rllib/agents/ars/policies.py @@ -7,13 +7,15 @@ import gym import numpy as np -import tensorflow as tf import ray import ray.experimental.tf_utils from ray.rllib.evaluation.sampler import _unbatch_tuple_actions from ray.rllib.utils.filter import get_filter from ray.rllib.models import ModelCatalog +from ray.rllib.utils import try_import_tf + +tf = try_import_tf() def rollout(policy, env, timestep_limit=None, add_noise=False, offset=0): diff --git a/python/ray/rllib/agents/ars/utils.py b/python/ray/rllib/agents/ars/utils.py index 1575e46c3837..518fd3d00634 100644 --- a/python/ray/rllib/agents/ars/utils.py +++ b/python/ray/rllib/agents/ars/utils.py @@ -6,7 +6,9 @@ from __future__ import print_function import numpy as np -import tensorflow as tf +from ray.rllib.utils import try_import_tf + +tf = try_import_tf() def compute_ranks(x): diff --git a/python/ray/rllib/agents/ddpg/__init__.py b/python/ray/rllib/agents/ddpg/__init__.py index 9b90ca842ae5..3d681b8356c9 100644 --- a/python/ray/rllib/agents/ddpg/__init__.py +++ b/python/ray/rllib/agents/ddpg/__init__.py @@ -5,10 +5,10 @@ from ray.rllib.agents.ddpg.apex import ApexDDPGTrainer from ray.rllib.agents.ddpg.ddpg import DDPGTrainer, DEFAULT_CONFIG from ray.rllib.agents.ddpg.td3 import TD3Trainer -from ray.rllib.utils import renamed_class +from ray.rllib.utils import renamed_agent -ApexDDPGAgent = renamed_class(ApexDDPGTrainer) -DDPGAgent = renamed_class(DDPGTrainer) +ApexDDPGAgent = renamed_agent(ApexDDPGTrainer) +DDPGAgent = renamed_agent(DDPGTrainer) __all__ = [ "DDPGAgent", "ApexDDPGAgent", "DDPGTrainer", "ApexDDPGTrainer", diff --git a/python/ray/rllib/agents/ddpg/ddpg.py 
b/python/ray/rllib/agents/ddpg/ddpg.py index 7a140beeea24..66d3810e5e93 100644 --- a/python/ray/rllib/agents/ddpg/ddpg.py +++ b/python/ray/rllib/agents/ddpg/ddpg.py @@ -4,7 +4,7 @@ from ray.rllib.agents.trainer import with_common_config from ray.rllib.agents.dqn.dqn import DQNTrainer -from ray.rllib.agents.ddpg.ddpg_policy_graph import DDPGPolicyGraph +from ray.rllib.agents.ddpg.ddpg_policy import DDPGTFPolicy from ray.rllib.utils.annotations import override from ray.rllib.utils.schedules import ConstantSchedule, LinearSchedule @@ -163,7 +163,7 @@ class DDPGTrainer(DQNTrainer): """DDPG implementation in TensorFlow.""" _name = "DDPG" _default_config = DEFAULT_CONFIG - _policy_graph = DDPGPolicyGraph + _policy = DDPGTFPolicy @override(DQNTrainer) def _train(self): diff --git a/python/ray/rllib/agents/ddpg/ddpg_policy_graph.py b/python/ray/rllib/agents/ddpg/ddpg_policy.py similarity index 92% rename from python/ray/rllib/agents/ddpg/ddpg_policy_graph.py rename to python/ray/rllib/agents/ddpg/ddpg_policy.py index 9304cbe0b598..b80cfce4cdaa 100644 --- a/python/ray/rllib/agents/ddpg/ddpg_policy_graph.py +++ b/python/ray/rllib/agents/ddpg/ddpg_policy.py @@ -4,20 +4,21 @@ from gym.spaces import Box import numpy as np -import tensorflow as tf -import tensorflow.contrib.layers as layers import ray import ray.experimental.tf_utils -from ray.rllib.agents.dqn.dqn_policy_graph import ( - _huber_loss, _minimize_and_clip, _scope_vars, _postprocess_dqn) -from ray.rllib.evaluation.sample_batch import SampleBatch +from ray.rllib.agents.dqn.dqn_policy import (_huber_loss, _minimize_and_clip, + _scope_vars, _postprocess_dqn) +from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY from ray.rllib.models import ModelCatalog from ray.rllib.utils.annotations import override from ray.rllib.utils.error import UnsupportedSpaceException -from ray.rllib.evaluation.policy_graph import PolicyGraph -from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph +from ray.rllib.policy.policy import Policy +from ray.rllib.policy.tf_policy import TFPolicy +from ray.rllib.utils import try_import_tf + +tf = try_import_tf() ACTION_SCOPE = "action" POLICY_SCOPE = "policy" @@ -34,7 +35,7 @@ class DDPGPostprocessing(object): """Implements n-step learning and param noise adjustments.""" - @override(PolicyGraph) + @override(Policy) def postprocess_trajectory(self, sample_batch, other_agent_batches=None, @@ -67,7 +68,7 @@ def postprocess_trajectory(self, return _postprocess_dqn(self, sample_batch) -class DDPGPolicyGraph(DDPGPostprocessing, TFPolicyGraph): +class DDPGTFPolicy(DDPGPostprocessing, TFPolicy): def __init__(self, observation_space, action_space, config): config = dict(ray.rllib.agents.ddpg.ddpg.DEFAULT_CONFIG, **config) if not isinstance(action_space, Box): @@ -165,8 +166,9 @@ def __init__(self, observation_space, action_space, config): stddev=self.config["target_noise"]), -target_noise_clip, target_noise_clip) policy_tp1_smoothed = tf.clip_by_value( - policy_tp1 + clipped_normal_sample, action_space.low, - action_space.high) + policy_tp1 + clipped_normal_sample, + action_space.low * tf.ones_like(policy_tp1), + action_space.high * tf.ones_like(policy_tp1)) else: # no smoothing, just use deterministic actions policy_tp1_smoothed = policy_tp1 @@ -279,7 +281,7 @@ def __init__(self, observation_space, action_space, config): self.critic_loss = self.twin_q_model.custom_loss( self.critic_loss, input_dict) - TFPolicyGraph.__init__( + TFPolicy.__init__( self, 
observation_space, action_space, @@ -299,12 +301,12 @@ def __init__(self, observation_space, action_space, config): # Hard initial update self.update_target(tau=1.0) - @override(TFPolicyGraph) + @override(TFPolicy) def optimizer(self): # we don't use this because we have two separate optimisers return None - @override(TFPolicyGraph) + @override(TFPolicy) def build_apply_op(self, optimizer, grads_and_vars): # for policy gradient, update policy net one time v.s. # update critic net `policy_delay` time(s) @@ -325,7 +327,7 @@ def make_apply_op(): with tf.control_dependencies([tf.assign_add(self.global_step, 1)]): return tf.group(actor_op, critic_op) - @override(TFPolicyGraph) + @override(TFPolicy) def gradients(self, optimizer, loss): if self.config["grad_norm_clipping"] is not None: actor_grads_and_vars = _minimize_and_clip( @@ -358,7 +360,7 @@ def gradients(self, optimizer, loss): + self._critic_grads_and_vars return grads_and_vars - @override(TFPolicyGraph) + @override(TFPolicy) def extra_compute_action_feed_dict(self): return { # FIXME: what about turning off exploration? Isn't that a good @@ -368,31 +370,31 @@ def extra_compute_action_feed_dict(self): self.pure_exploration_phase: self.cur_pure_exploration_phase, } - @override(TFPolicyGraph) + @override(TFPolicy) def extra_compute_grad_fetches(self): return { "td_error": self.td_error, LEARNER_STATS_KEY: self.stats, } - @override(TFPolicyGraph) + @override(TFPolicy) def get_weights(self): return self.variables.get_weights() - @override(TFPolicyGraph) + @override(TFPolicy) def set_weights(self, weights): self.variables.set_weights(weights) - @override(PolicyGraph) + @override(Policy) def get_state(self): return [ - TFPolicyGraph.get_state(self), self.cur_noise_scale, + TFPolicy.get_state(self), self.cur_noise_scale, self.cur_pure_exploration_phase ] - @override(PolicyGraph) + @override(Policy) def set_state(self, state): - TFPolicyGraph.set_state(self, state[0]) + TFPolicy.set_state(self, state[0]) self.set_epsilon(state[1]) self.set_pure_exploration_phase(state[2]) @@ -409,10 +411,8 @@ def _build_q_network(self, obs, obs_space, action_space, actions): activation = getattr(tf.nn, self.config["critic_hidden_activation"]) for hidden in self.config["critic_hiddens"]: - q_out = layers.fully_connected( - q_out, num_outputs=hidden, activation_fn=activation) - q_values = layers.fully_connected( - q_out, num_outputs=1, activation_fn=None) + q_out = tf.layers.dense(q_out, units=hidden, activation=activation) + q_values = tf.layers.dense(q_out, units=1, activation=None) return q_values, q_model @@ -428,16 +428,19 @@ def _build_policy_network(self, obs, obs_space, action_space): action_out = obs activation = getattr(tf.nn, self.config["actor_hidden_activation"]) - normalizer_fn = layers.layer_norm if self.config["parameter_noise"] \ - else None for hidden in self.config["actor_hiddens"]: - action_out = layers.fully_connected( - action_out, - num_outputs=hidden, - activation_fn=activation, - normalizer_fn=normalizer_fn) - action_out = layers.fully_connected( - action_out, num_outputs=self.dim_actions, activation_fn=None) + if self.config["parameter_noise"]: + import tensorflow.contrib.layers as layers + action_out = layers.fully_connected( + action_out, + num_outputs=hidden, + activation_fn=activation, + normalizer_fn=layers.layer_norm) + else: + action_out = tf.layers.dense( + action_out, units=hidden, activation=activation) + action_out = tf.layers.dense( + action_out, units=self.dim_actions, activation=None) # Use sigmoid to scale to [0,1], but also 
double magnitude of input to # emulate behaviour of tanh activation used in DDPG and TD3 papers. @@ -468,8 +471,9 @@ def make_noisy_actions(): tf.shape(deterministic_actions), stddev=self.config["exploration_gaussian_sigma"]) stochastic_actions = tf.clip_by_value( - deterministic_actions + normal_sample, action_low, - action_high) + deterministic_actions + normal_sample, + action_low * tf.ones_like(deterministic_actions), + action_high * tf.ones_like(deterministic_actions)) elif noise_type == "ou": # add OU noise for exploration, DDPG-style zero_acts = action_low.size * [.0] @@ -489,7 +493,9 @@ def make_noisy_actions(): noise = noise_scale * base_scale \ * exploration_value * action_range stochastic_actions = tf.clip_by_value( - deterministic_actions + noise, action_low, action_high) + deterministic_actions + noise, + action_low * tf.ones_like(deterministic_actions), + action_high * tf.ones_like(deterministic_actions)) else: raise ValueError( "Unknown noise type '%s' (try 'ou' or 'gaussian')" % @@ -498,7 +504,7 @@ def make_noisy_actions(): def make_uniform_random_actions(): # pure random exploration option - uniform_random_actions = tf.random.uniform( + uniform_random_actions = tf.random_uniform( tf.shape(deterministic_actions)) # rescale uniform random actions according to action range tf_range = tf.constant(action_range[None], dtype="float32") diff --git a/python/ray/rllib/agents/dqn/__init__.py b/python/ray/rllib/agents/dqn/__init__.py index 415ceae6c1de..d3de8cb802cc 100644 --- a/python/ray/rllib/agents/dqn/__init__.py +++ b/python/ray/rllib/agents/dqn/__init__.py @@ -4,10 +4,10 @@ from ray.rllib.agents.dqn.apex import ApexTrainer from ray.rllib.agents.dqn.dqn import DQNTrainer, DEFAULT_CONFIG -from ray.rllib.utils import renamed_class +from ray.rllib.utils import renamed_agent -DQNAgent = renamed_class(DQNTrainer) -ApexAgent = renamed_class(ApexTrainer) +DQNAgent = renamed_agent(DQNTrainer) +ApexAgent = renamed_agent(ApexTrainer) __all__ = [ "DQNAgent", "ApexAgent", "ApexTrainer", "DQNTrainer", "DEFAULT_CONFIG" diff --git a/python/ray/rllib/agents/dqn/dqn.py b/python/ray/rllib/agents/dqn/dqn.py index d8fb480cbda6..7fdb6f66b433 100644 --- a/python/ray/rllib/agents/dqn/dqn.py +++ b/python/ray/rllib/agents/dqn/dqn.py @@ -8,9 +8,9 @@ from ray import tune from ray.rllib import optimizers from ray.rllib.agents.trainer import Trainer, with_common_config -from ray.rllib.agents.dqn.dqn_policy_graph import DQNPolicyGraph +from ray.rllib.agents.dqn.dqn_policy import DQNTFPolicy from ray.rllib.evaluation.metrics import collect_metrics -from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID +from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID from ray.rllib.utils.annotations import override from ray.rllib.utils.schedules import ConstantSchedule, LinearSchedule @@ -133,7 +133,7 @@ class DQNTrainer(Trainer): _name = "DQN" _default_config = DEFAULT_CONFIG - _policy_graph = DQNPolicyGraph + _policy = DQNTFPolicy _optimizer_shared_configs = OPTIMIZER_SHARED_CONFIGS @override(Trainer) @@ -197,10 +197,10 @@ def on_episode_end(info): on_episode_end) self.local_evaluator = self.make_local_evaluator( - env_creator, self._policy_graph) + env_creator, self._policy) def create_remote_evaluators(): - return self.make_remote_evaluators(env_creator, self._policy_graph, + return self.make_remote_evaluators(env_creator, self._policy, config["num_workers"]) if config["optimizer_class"] != "AsyncReplayOptimizer": diff --git a/python/ray/rllib/agents/dqn/dqn_policy_graph.py 
b/python/ray/rllib/agents/dqn/dqn_policy.py similarity index 92% rename from python/ray/rllib/agents/dqn/dqn_policy_graph.py rename to python/ray/rllib/agents/dqn/dqn_policy.py index 6a226d237461..a1affa947a43 100644 --- a/python/ray/rllib/agents/dqn/dqn_policy_graph.py +++ b/python/ray/rllib/agents/dqn/dqn_policy.py @@ -5,18 +5,19 @@ from gym.spaces import Discrete import numpy as np from scipy.stats import entropy -import tensorflow as tf -import tensorflow.contrib.layers as layers import ray -from ray.rllib.evaluation.sample_batch import SampleBatch +from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY from ray.rllib.models import ModelCatalog, Categorical from ray.rllib.utils.annotations import override from ray.rllib.utils.error import UnsupportedSpaceException -from ray.rllib.evaluation.policy_graph import PolicyGraph -from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph, \ +from ray.rllib.policy.policy import Policy +from ray.rllib.policy.tf_policy import TFPolicy, \ LearningRateSchedule +from ray.rllib.utils import try_import_tf + +tf = try_import_tf() Q_SCOPE = "q_func" Q_TARGET_SCOPE = "target_q_func" @@ -104,14 +105,14 @@ def __init__(self, class DQNPostprocessing(object): """Implements n-step learning and param noise adjustments.""" - @override(TFPolicyGraph) + @override(TFPolicy) def extra_compute_action_fetches(self): return dict( - TFPolicyGraph.extra_compute_action_fetches(self), **{ + TFPolicy.extra_compute_action_fetches(self), **{ "q_values": self.q_values, }) - @override(PolicyGraph) + @override(Policy) def postprocess_trajectory(self, sample_batch, other_agent_batches=None, @@ -161,13 +162,18 @@ def __init__(self, if use_noisy: action_out = self.noisy_layer( "hidden_%d" % i, action_out, hiddens[i], sigma0) - else: + elif parameter_noise: + import tensorflow.contrib.layers as layers action_out = layers.fully_connected( action_out, num_outputs=hiddens[i], activation_fn=tf.nn.relu, - normalizer_fn=layers.layer_norm - if parameter_noise else None) + normalizer_fn=layers.layer_norm) + else: + action_out = tf.layers.dense( + action_out, + units=hiddens[i], + activation=tf.nn.relu) else: # Avoid postprocessing the outputs. This enables custom models # to be used for parametric action DQN. 
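# --- Illustrative aside (not part of the patch) ----------------------------
# The hunks above replace `tf.contrib.layers.fully_connected` with
# `tf.layers.dense` and only import `tensorflow.contrib.layers` inside the
# `parameter_noise` branch, where layer normalization is still needed.
# A minimal sketch of that pattern follows, assuming TensorFlow 1.x;
# `action_out`, `hidden`, `activation`, and `parameter_noise` are
# hypothetical stand-ins for the values used in the hunk above.
import tensorflow as tf  # TF 1.x assumed


def hidden_layer(action_out, hidden, activation, parameter_noise):
    if parameter_noise:
        # contrib layers are only needed when layer_norm is applied.
        import tensorflow.contrib.layers as layers
        return layers.fully_connected(
            action_out,
            num_outputs=hidden,
            activation_fn=activation,
            normalizer_fn=layers.layer_norm)
    # Without a normalizer, a plain dense layer is roughly equivalent to
    # fully_connected.
    return tf.layers.dense(action_out, units=hidden, activation=activation)
# ----------------------------------------------------------------------------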
@@ -180,10 +186,8 @@ def __init__(self, sigma0, non_linear=False) elif hiddens: - action_scores = layers.fully_connected( - action_out, - num_outputs=num_actions * num_atoms, - activation_fn=None) + action_scores = tf.layers.dense( + action_out, units=num_actions * num_atoms, activation=None) else: action_scores = model.outputs if num_atoms > 1: @@ -211,13 +215,15 @@ def __init__(self, state_out = self.noisy_layer("dueling_hidden_%d" % i, state_out, hiddens[i], sigma0) - else: - state_out = layers.fully_connected( + elif parameter_noise: + state_out = tf.contrib.layers.fully_connected( state_out, num_outputs=hiddens[i], activation_fn=tf.nn.relu, - normalizer_fn=layers.layer_norm - if parameter_noise else None) + normalizer_fn=tf.contrib.layers.layer_norm) + else: + state_out = tf.layers.dense( + state_out, units=hiddens[i], activation=tf.nn.relu) if use_noisy: state_score = self.noisy_layer( "dueling_output", @@ -226,8 +232,8 @@ def __init__(self, sigma0, non_linear=False) else: - state_score = layers.fully_connected( - state_out, num_outputs=num_atoms, activation_fn=None) + state_score = tf.layers.dense( + state_out, units=num_atoms, activation=None) if num_atoms > 1: support_logits_per_action_mean = tf.reduce_mean( support_logits_per_action, 1) @@ -263,6 +269,8 @@ def noisy_layer(self, prefix, action_in, out_size, sigma0, distributions and \sigma are trainable variables which are expected to vanish along the training procedure """ + import tensorflow.contrib.layers as layers + in_size = int(action_in.shape[1]) epsilon_in = tf.random_normal(shape=[in_size]) @@ -337,7 +345,7 @@ def __init__(self, q_values, observations, num_actions, stochastic, eps, self.action_prob = None -class DQNPolicyGraph(LearningRateSchedule, DQNPostprocessing, TFPolicyGraph): +class DQNTFPolicy(LearningRateSchedule, DQNPostprocessing, TFPolicy): def __init__(self, observation_space, action_space, config): config = dict(ray.rllib.agents.dqn.dqn.DEFAULT_CONFIG, **config) if not isinstance(action_space, Discrete): @@ -438,7 +446,7 @@ def __init__(self, observation_space, action_space, config): update_target_expr.append(var_target.assign(var)) self.update_target_expr = tf.group(*update_target_expr) - # initialize TFPolicyGraph + # initialize TFPolicy self.sess = tf.get_default_session() self.loss_inputs = [ (SampleBatch.CUR_OBS, self.obs_t), @@ -451,7 +459,7 @@ def __init__(self, observation_space, action_space, config): LearningRateSchedule.__init__(self, self.config["lr"], self.config["lr_schedule"]) - TFPolicyGraph.__init__( + TFPolicy.__init__( self, observation_space, action_space, @@ -469,12 +477,12 @@ def __init__(self, observation_space, action_space, config): "cur_lr": tf.cast(self.cur_lr, tf.float64), }, **self.loss.stats) - @override(TFPolicyGraph) + @override(TFPolicy) def optimizer(self): return tf.train.AdamOptimizer( learning_rate=self.cur_lr, epsilon=self.config["adam_epsilon"]) - @override(TFPolicyGraph) + @override(TFPolicy) def gradients(self, optimizer, loss): if self.config["grad_norm_clipping"] is not None: grads_and_vars = _minimize_and_clip( @@ -488,27 +496,27 @@ def gradients(self, optimizer, loss): grads_and_vars = [(g, v) for (g, v) in grads_and_vars if g is not None] return grads_and_vars - @override(TFPolicyGraph) + @override(TFPolicy) def extra_compute_action_feed_dict(self): return { self.stochastic: True, self.eps: self.cur_epsilon, } - @override(TFPolicyGraph) + @override(TFPolicy) def extra_compute_grad_fetches(self): return { "td_error": self.loss.td_error, LEARNER_STATS_KEY: 
self.stats_fetches, } - @override(PolicyGraph) + @override(Policy) def get_state(self): - return [TFPolicyGraph.get_state(self), self.cur_epsilon] + return [TFPolicy.get_state(self), self.cur_epsilon] - @override(PolicyGraph) + @override(Policy) def set_state(self, state): - TFPolicyGraph.set_state(self, state[0]) + TFPolicy.set_state(self, state[0]) self.set_epsilon(state[1]) def _build_parameter_noise(self, pnet_params): @@ -625,25 +633,25 @@ def _adjust_nstep(n_step, gamma, obs, actions, rewards, new_obs, dones): rewards[i] += gamma**j * rewards[i + j] -def _postprocess_dqn(policy_graph, batch): +def _postprocess_dqn(policy, batch): # N-step Q adjustments - if policy_graph.config["n_step"] > 1: - _adjust_nstep(policy_graph.config["n_step"], - policy_graph.config["gamma"], batch[SampleBatch.CUR_OBS], - batch[SampleBatch.ACTIONS], batch[SampleBatch.REWARDS], - batch[SampleBatch.NEXT_OBS], batch[SampleBatch.DONES]) + if policy.config["n_step"] > 1: + _adjust_nstep(policy.config["n_step"], policy.config["gamma"], + batch[SampleBatch.CUR_OBS], batch[SampleBatch.ACTIONS], + batch[SampleBatch.REWARDS], batch[SampleBatch.NEXT_OBS], + batch[SampleBatch.DONES]) if PRIO_WEIGHTS not in batch: batch[PRIO_WEIGHTS] = np.ones_like(batch[SampleBatch.REWARDS]) # Prioritize on the worker side - if batch.count > 0 and policy_graph.config["worker_side_prioritization"]: - td_errors = policy_graph.compute_td_error( + if batch.count > 0 and policy.config["worker_side_prioritization"]: + td_errors = policy.compute_td_error( batch[SampleBatch.CUR_OBS], batch[SampleBatch.ACTIONS], batch[SampleBatch.REWARDS], batch[SampleBatch.NEXT_OBS], batch[SampleBatch.DONES], batch[PRIO_WEIGHTS]) new_priorities = ( - np.abs(td_errors) + policy_graph.config["prioritized_replay_eps"]) + np.abs(td_errors) + policy.config["prioritized_replay_eps"]) batch.data[PRIO_WEIGHTS] = new_priorities return batch diff --git a/python/ray/rllib/agents/es/__init__.py b/python/ray/rllib/agents/es/__init__.py index d7bec2a9e002..38b2b772ec57 100644 --- a/python/ray/rllib/agents/es/__init__.py +++ b/python/ray/rllib/agents/es/__init__.py @@ -1,6 +1,6 @@ from ray.rllib.agents.es.es import (ESTrainer, DEFAULT_CONFIG) -from ray.rllib.utils import renamed_class +from ray.rllib.utils import renamed_agent -ESAgent = renamed_class(ESTrainer) +ESAgent = renamed_agent(ESTrainer) __all__ = ["ESAgent", "ESTrainer", "DEFAULT_CONFIG"] diff --git a/python/ray/rllib/agents/es/es.py b/python/ray/rllib/agents/es/es.py index 2328b90e9ed0..e167129c6a93 100644 --- a/python/ray/rllib/agents/es/es.py +++ b/python/ray/rllib/agents/es/es.py @@ -16,7 +16,7 @@ from ray.rllib.agents.es import optimizers from ray.rllib.agents.es import policies from ray.rllib.agents.es import utils -from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID +from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID from ray.rllib.utils.annotations import override from ray.rllib.utils.memory import ray_get_and_free from ray.rllib.utils import FilterManager diff --git a/python/ray/rllib/agents/es/policies.py b/python/ray/rllib/agents/es/policies.py index 78ff29da4f86..dfc7e2deec47 100644 --- a/python/ray/rllib/agents/es/policies.py +++ b/python/ray/rllib/agents/es/policies.py @@ -7,13 +7,15 @@ import gym import numpy as np -import tensorflow as tf import ray import ray.experimental.tf_utils from ray.rllib.evaluation.sampler import _unbatch_tuple_actions from ray.rllib.models import ModelCatalog from ray.rllib.utils.filter import get_filter +from ray.rllib.utils import try_import_tf + 
+tf = try_import_tf() def rollout(policy, env, timestep_limit=None, add_noise=False): diff --git a/python/ray/rllib/agents/es/utils.py b/python/ray/rllib/agents/es/utils.py index 1575e46c3837..518fd3d00634 100644 --- a/python/ray/rllib/agents/es/utils.py +++ b/python/ray/rllib/agents/es/utils.py @@ -6,7 +6,9 @@ from __future__ import print_function import numpy as np -import tensorflow as tf +from ray.rllib.utils import try_import_tf + +tf = try_import_tf() def compute_ranks(x): diff --git a/python/ray/rllib/agents/impala/__init__.py b/python/ray/rllib/agents/impala/__init__.py index 81c64e8891ab..d7bdd7210fdd 100644 --- a/python/ray/rllib/agents/impala/__init__.py +++ b/python/ray/rllib/agents/impala/__init__.py @@ -1,6 +1,6 @@ from ray.rllib.agents.impala.impala import ImpalaTrainer, DEFAULT_CONFIG -from ray.rllib.utils import renamed_class +from ray.rllib.utils import renamed_agent -ImpalaAgent = renamed_class(ImpalaTrainer) +ImpalaAgent = renamed_agent(ImpalaTrainer) __all__ = ["ImpalaAgent", "ImpalaTrainer", "DEFAULT_CONFIG"] diff --git a/python/ray/rllib/agents/impala/impala.py b/python/ray/rllib/agents/impala/impala.py index ffe74c087a3e..838f2975ce67 100644 --- a/python/ray/rllib/agents/impala/impala.py +++ b/python/ray/rllib/agents/impala/impala.py @@ -4,8 +4,8 @@ import time -from ray.rllib.agents.a3c.a3c_tf_policy_graph import A3CPolicyGraph -from ray.rllib.agents.impala.vtrace_policy_graph import VTracePolicyGraph +from ray.rllib.agents.a3c.a3c_tf_policy import A3CTFPolicy +from ray.rllib.agents.impala.vtrace_policy import VTraceTFPolicy from ray.rllib.agents.trainer import Trainer, with_common_config from ray.rllib.optimizers import AsyncSamplesOptimizer from ray.rllib.optimizers.aso_tree_aggregator import TreeAggregator @@ -105,14 +105,14 @@ class ImpalaTrainer(Trainer): _name = "IMPALA" _default_config = DEFAULT_CONFIG - _policy_graph = VTracePolicyGraph + _policy = VTraceTFPolicy @override(Trainer) def _init(self, config, env_creator): for k in OPTIMIZER_SHARED_CONFIGS: if k not in config["optimizer"]: config["optimizer"][k] = config[k] - policy_cls = self._get_policy_graph() + policy_cls = self._get_policy() self.local_evaluator = self.make_local_evaluator( self.env_creator, policy_cls) @@ -158,9 +158,9 @@ def _train(self): prev_steps) return result - def _get_policy_graph(self): + def _get_policy(self): if self.config["vtrace"]: - policy_cls = self._policy_graph + policy_cls = self._policy else: - policy_cls = A3CPolicyGraph + policy_cls = A3CTFPolicy return policy_cls diff --git a/python/ray/rllib/agents/impala/vtrace.py b/python/ray/rllib/agents/impala/vtrace.py index 238b30d99355..67e76929dfc3 100644 --- a/python/ray/rllib/agents/impala/vtrace.py +++ b/python/ray/rllib/agents/impala/vtrace.py @@ -34,9 +34,10 @@ import collections -import tensorflow as tf +from ray.rllib.models.action_dist import Categorical +from ray.rllib.utils import try_import_tf -nest = tf.contrib.framework.nest +tf = try_import_tf() VTraceFromLogitsReturns = collections.namedtuple("VTraceFromLogitsReturns", [ "vs", "pg_advantages", "log_rhos", "behaviour_action_log_probs", @@ -46,12 +47,15 @@ VTraceReturns = collections.namedtuple("VTraceReturns", "vs pg_advantages") -def log_probs_from_logits_and_actions(policy_logits, actions): - return multi_log_probs_from_logits_and_actions([policy_logits], - [actions])[0] +def log_probs_from_logits_and_actions(policy_logits, + actions, + dist_class=Categorical): + return multi_log_probs_from_logits_and_actions([policy_logits], [actions], + dist_class)[0] -def 
multi_log_probs_from_logits_and_actions(policy_logits, actions): +def multi_log_probs_from_logits_and_actions(policy_logits, actions, + dist_class): """Computes action log-probs from policy logits and actions. In the notation used throughout documentation and comments, T refers to the @@ -66,11 +70,11 @@ def multi_log_probs_from_logits_and_actions(policy_logits, actions): ..., [T, B, ACTION_SPACE[-1]] with un-normalized log-probabilities parameterizing a softmax policy. - actions: A list with length of ACTION_SPACE of int32 + actions: A list with length of ACTION_SPACE of tensors of shapes - [T, B], + [T, B, ...], ..., - [T, B] + [T, B, ...] with actions. Returns: @@ -85,8 +89,16 @@ def multi_log_probs_from_logits_and_actions(policy_logits, actions): log_probs = [] for i in range(len(policy_logits)): - log_probs.append(-tf.nn.sparse_softmax_cross_entropy_with_logits( - logits=policy_logits[i], labels=actions[i])) + p_shape = tf.shape(policy_logits[i]) + a_shape = tf.shape(actions[i]) + policy_logits_flat = tf.reshape(policy_logits[i], + tf.concat([[-1], p_shape[2:]], axis=0)) + actions_flat = tf.reshape(actions[i], + tf.concat([[-1], a_shape[2:]], axis=0)) + log_probs.append( + tf.reshape( + dist_class(policy_logits_flat).logp(actions_flat), + a_shape[:2])) return log_probs @@ -98,6 +110,7 @@ def from_logits(behaviour_policy_logits, rewards, values, bootstrap_value, + dist_class=Categorical, clip_rho_threshold=1.0, clip_pg_rho_threshold=1.0, name="vtrace_from_logits"): @@ -109,6 +122,7 @@ def from_logits(behaviour_policy_logits, rewards, values, bootstrap_value, + dist_class, clip_rho_threshold=clip_rho_threshold, clip_pg_rho_threshold=clip_pg_rho_threshold, name=name) @@ -131,6 +145,7 @@ def multi_from_logits(behaviour_policy_logits, rewards, values, bootstrap_value, + dist_class, clip_rho_threshold=1.0, clip_pg_rho_threshold=1.0, name="vtrace_from_logits"): @@ -166,11 +181,11 @@ def multi_from_logits(behaviour_policy_logits, [T, B, ACTION_SPACE[-1]] with un-normalized log-probabilities parameterizing the softmax target policy. - actions: A list with length of ACTION_SPACE of int32 + actions: A list with length of ACTION_SPACE of tensors of shapes - [T, B], + [T, B, ...], ..., - [T, B] + [T, B, ...] with actions sampled from the behaviour policy. discounts: A float32 tensor of shape [T, B] with the discount encountered when following the behaviour policy. @@ -180,6 +195,7 @@ def multi_from_logits(behaviour_policy_logits, wrt. the target policy. bootstrap_value: A float32 of shape [B] with the value function estimate at time T. + dist_class: action distribution class for the logits. clip_rho_threshold: A scalar float32 tensor with the clipping threshold for importance weights (rho) when calculating the baseline targets (vs). rho^bar in the paper. @@ -206,13 +222,11 @@ def multi_from_logits(behaviour_policy_logits, behaviour_policy_logits[i], dtype=tf.float32) target_policy_logits[i] = tf.convert_to_tensor( target_policy_logits[i], dtype=tf.float32) - actions[i] = tf.convert_to_tensor(actions[i], dtype=tf.int32) # Make sure tensor ranks are as expected. # The rest will be checked by from_action_log_probs. 
behaviour_policy_logits[i].shape.assert_has_rank(3) target_policy_logits[i].shape.assert_has_rank(3) - actions[i].shape.assert_has_rank(2) with tf.name_scope( name, @@ -221,9 +235,9 @@ def multi_from_logits(behaviour_policy_logits, discounts, rewards, values, bootstrap_value ]): target_action_log_probs = multi_log_probs_from_logits_and_actions( - target_policy_logits, actions) + target_policy_logits, actions, dist_class) behaviour_action_log_probs = multi_log_probs_from_logits_and_actions( - behaviour_policy_logits, actions) + behaviour_policy_logits, actions, dist_class) log_rhos = get_log_rhos(target_action_log_probs, behaviour_action_log_probs) diff --git a/python/ray/rllib/agents/impala/vtrace_policy_graph.py b/python/ray/rllib/agents/impala/vtrace_policy.py similarity index 89% rename from python/ray/rllib/agents/impala/vtrace_policy_graph.py rename to python/ray/rllib/agents/impala/vtrace_policy.py index 1798fb7e64ec..7e8867ab2ac0 100644 --- a/python/ray/rllib/agents/impala/vtrace_policy_graph.py +++ b/python/ray/rllib/agents/impala/vtrace_policy.py @@ -1,6 +1,6 @@ -"""Adapted from A3CPolicyGraph to add V-trace. +"""Adapted from A3CTFPolicy to add V-trace. -Keep in sync with changes to A3CPolicyGraph and VtraceSurrogatePolicyGraph.""" +Keep in sync with changes to A3CTFPolicy and VtraceSurrogatePolicy.""" from __future__ import absolute_import from __future__ import division @@ -9,18 +9,19 @@ import gym import ray import numpy as np -import tensorflow as tf from ray.rllib.agents.impala import vtrace from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY -from ray.rllib.evaluation.policy_graph import PolicyGraph -from ray.rllib.evaluation.sample_batch import SampleBatch -from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph, \ +from ray.rllib.policy.policy import Policy +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.policy.tf_policy import TFPolicy, \ LearningRateSchedule from ray.rllib.models.action_dist import MultiCategorical from ray.rllib.models.catalog import ModelCatalog from ray.rllib.utils.annotations import override -from ray.rllib.utils.error import UnsupportedSpaceException from ray.rllib.utils.explained_variance import explained_variance +from ray.rllib.utils import try_import_tf + +tf = try_import_tf() # Frozen logits of the policy that computed the action BEHAVIOUR_LOGITS = "behaviour_logits" @@ -38,6 +39,7 @@ def __init__(self, rewards, values, bootstrap_value, + dist_class, valid_mask, vf_loss_coeff=0.5, entropy_coeff=0.01, @@ -50,7 +52,7 @@ def __init__(self, handle episode cut boundaries. Args: - actions: An int32 tensor of shape [T, B, ACTION_SPACE]. + actions: An int|float32 tensor of shape [T, B, ACTION_SPACE]. actions_logp: A float32 tensor of shape [T, B]. actions_entropy: A float32 tensor of shape [T, B]. dones: A bool tensor of shape [T, B]. @@ -68,6 +70,7 @@ def __init__(self, rewards: A float32 tensor of shape [T, B]. values: A float32 tensor of shape [T, B]. bootstrap_value: A float32 tensor of shape [B]. + dist_class: action distribution class for logits. valid_mask: A bool tensor of valid RNN input elements (#2992). 
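# Editor's sketch (not part of the patch): how the importance ratios named by
# clip_rho_threshold / clip_pg_rho_threshold in the docstrings above are
# formed. The helper name and numbers are invented; the quantities are the
# target vs. behaviour log-probs and the rho-bar clipping from the V-trace
# description.
import numpy as np

def clipped_rhos(target_logp, behaviour_logp, clip_rho_threshold=1.0):
    log_rhos = target_logp - behaviour_logp           # [T, B]
    rhos = np.exp(log_rhos)                           # importance weights
    return np.minimum(rhos, clip_rho_threshold)       # rho-bar clipping

target_logp = np.log(np.array([[0.5, 0.2], [0.9, 0.1]]))
behaviour_logp = np.log(np.array([[0.4, 0.4], [0.5, 0.2]]))
print(clipped_rhos(target_logp, behaviour_logp))      # values capped at 1.0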
""" @@ -76,11 +79,12 @@ def __init__(self, self.vtrace_returns = vtrace.multi_from_logits( behaviour_policy_logits=behaviour_logits, target_policy_logits=target_logits, - actions=tf.unstack(tf.cast(actions, tf.int32), axis=2), + actions=tf.unstack(actions, axis=2), discounts=tf.to_float(~dones) * discount, rewards=rewards, values=values, bootstrap_value=bootstrap_value, + dist_class=dist_class, clip_rho_threshold=tf.cast(clip_rho_threshold, tf.float32), clip_pg_rho_threshold=tf.cast(clip_pg_rho_threshold, tf.float32)) @@ -106,13 +110,13 @@ def __init__(self, class VTracePostprocessing(object): """Adds the policy logits to the trajectory.""" - @override(TFPolicyGraph) + @override(TFPolicy) def extra_compute_action_fetches(self): return dict( - TFPolicyGraph.extra_compute_action_fetches(self), + TFPolicy.extra_compute_action_fetches(self), **{BEHAVIOUR_LOGITS: self.model.outputs}) - @override(PolicyGraph) + @override(Policy) def postprocess_trajectory(self, sample_batch, other_agent_batches=None, @@ -122,8 +126,7 @@ def postprocess_trajectory(self, return sample_batch -class VTracePolicyGraph(LearningRateSchedule, VTracePostprocessing, - TFPolicyGraph): +class VTraceTFPolicy(LearningRateSchedule, VTracePostprocessing, TFPolicy): def __init__(self, observation_space, action_space, @@ -138,30 +141,28 @@ def __init__(self, if isinstance(action_space, gym.spaces.Discrete): is_multidiscrete = False - actions_shape = [None] output_hidden_shape = [action_space.n] elif isinstance(action_space, gym.spaces.multi_discrete.MultiDiscrete): is_multidiscrete = True - actions_shape = [None, len(action_space.nvec)] output_hidden_shape = action_space.nvec.astype(np.int32) else: - raise UnsupportedSpaceException( - "Action space {} is not supported for IMPALA.".format( - action_space)) + is_multidiscrete = False + output_hidden_shape = 1 # Create input placeholders + dist_class, logit_dim = ModelCatalog.get_action_dist( + action_space, self.config["model"]) if existing_inputs: actions, dones, behaviour_logits, rewards, observations, \ prev_actions, prev_rewards = existing_inputs[:7] existing_state_in = existing_inputs[7:-1] existing_seq_lens = existing_inputs[-1] else: - actions = tf.placeholder(tf.int64, actions_shape, name="ac") + actions = ModelCatalog.get_action_placeholder(action_space) dones = tf.placeholder(tf.bool, [None], name="dones") rewards = tf.placeholder(tf.float32, [None], name="rewards") behaviour_logits = tf.placeholder( - tf.float32, [None, sum(output_hidden_shape)], - name="behaviour_logits") + tf.float32, [None, logit_dim], name="behaviour_logits") observations = tf.placeholder( tf.float32, [None] + list(observation_space.shape), name='observations') @@ -173,8 +174,6 @@ def __init__(self, behaviour_logits, output_hidden_shape, axis=1) # Setup the policy - dist_class, logit_dim = ModelCatalog.get_action_dist( - action_space, self.config["model"]) prev_actions = ModelCatalog.get_action_placeholder(action_space) prev_rewards = tf.placeholder(tf.float32, [None], name="prev_reward") self.model = ModelCatalog.get_model( @@ -266,6 +265,7 @@ def make_time_major(tensor, drop_last=False): rewards=make_time_major(rewards, drop_last=True), values=make_time_major(values, drop_last=True), bootstrap_value=make_time_major(values)[-1], + dist_class=dist_class, valid_mask=make_time_major(mask, drop_last=True), vf_loss_coeff=self.config["vf_loss_coeff"], entropy_coeff=self.config["entropy_coeff"], @@ -285,17 +285,14 @@ def make_time_major(tensor, drop_last=False): self.KL_stats.update({ 
"mean_KL_{}".format(i): tf.reduce_mean(kl), "max_KL_{}".format(i): tf.reduce_max(kl), - "median_KL_{}".format(i): tf.contrib.distributions. - percentile(kl, 50.0), }) else: self.KL_stats = { "mean_KL": tf.reduce_mean(kls[0]), "max_KL": tf.reduce_max(kls[0]), - "median_KL": tf.contrib.distributions.percentile(kls[0], 50.0), } - # Initialize TFPolicyGraph + # Initialize TFPolicy loss_in = [ (SampleBatch.ACTIONS, actions), (SampleBatch.DONES, dones), @@ -311,7 +308,7 @@ def make_time_major(tensor, drop_last=False): with tf.name_scope('TFPolicyGraph.__init__'): self.state_values = values - TFPolicyGraph.__init__( + TFPolicy.__init__( self, observation_space, action_space, @@ -347,15 +344,15 @@ def make_time_major(tensor, drop_last=False): }, **self.KL_stats), } - @override(TFPolicyGraph) + @override(TFPolicy) def copy(self, existing_inputs): - return VTracePolicyGraph( + return VTraceTFPolicy( self.observation_space, self.action_space, self.config, existing_inputs=existing_inputs) - @override(TFPolicyGraph) + @override(TFPolicy) def optimizer(self): if self.config["opt_type"] == "adam": return tf.train.AdamOptimizer(self.cur_lr) @@ -364,7 +361,7 @@ def optimizer(self): self.config["momentum"], self.config["epsilon"]) - @override(TFPolicyGraph) + @override(TFPolicy) def gradients(self, optimizer, loss): grads = tf.gradients(loss, self.var_list) return self._clip_grads(grads) @@ -374,10 +371,10 @@ def _clip_grads(self, grads): clipped_grads = list(zip(self.grads, self.var_list)) return clipped_grads - @override(TFPolicyGraph) + @override(TFPolicy) def extra_compute_grad_fetches(self): return self.stats_fetches - @override(PolicyGraph) + @override(Policy) def get_initial_state(self): return self.model.state_init diff --git a/python/ray/rllib/agents/impala/vtrace_test.py b/python/ray/rllib/agents/impala/vtrace_test.py index 145ed4e7a2cd..e1f39991b097 100644 --- a/python/ray/rllib/agents/impala/vtrace_test.py +++ b/python/ray/rllib/agents/impala/vtrace_test.py @@ -26,8 +26,10 @@ from absl.testing import parameterized import numpy as np -import tensorflow as tf import vtrace +from ray.rllib.utils import try_import_tf + +tf = try_import_tf() def _shaped_arange(*shape): diff --git a/python/ray/rllib/agents/marwil/marwil.py b/python/ray/rllib/agents/marwil/marwil.py index b1e535b64530..d6c6eadeaa9c 100644 --- a/python/ray/rllib/agents/marwil/marwil.py +++ b/python/ray/rllib/agents/marwil/marwil.py @@ -3,7 +3,7 @@ from __future__ import print_function from ray.rllib.agents.trainer import Trainer, with_common_config -from ray.rllib.agents.marwil.marwil_policy_graph import MARWILPolicyGraph +from ray.rllib.agents.marwil.marwil_policy import MARWILPolicy from ray.rllib.optimizers import SyncBatchReplayOptimizer from ray.rllib.utils.annotations import override @@ -44,14 +44,14 @@ class MARWILTrainer(Trainer): _name = "MARWIL" _default_config = DEFAULT_CONFIG - _policy_graph = MARWILPolicyGraph + _policy = MARWILPolicy @override(Trainer) def _init(self, config, env_creator): self.local_evaluator = self.make_local_evaluator( - env_creator, self._policy_graph) + env_creator, self._policy) self.remote_evaluators = self.make_remote_evaluators( - env_creator, self._policy_graph, config["num_workers"]) + env_creator, self._policy, config["num_workers"]) self.optimizer = SyncBatchReplayOptimizer( self.local_evaluator, self.remote_evaluators, diff --git a/python/ray/rllib/agents/marwil/marwil_policy_graph.py b/python/ray/rllib/agents/marwil/marwil_policy.py similarity index 92% rename from 
python/ray/rllib/agents/marwil/marwil_policy_graph.py rename to python/ray/rllib/agents/marwil/marwil_policy.py index 2dd67ab5f39c..add021025c9c 100644 --- a/python/ray/rllib/agents/marwil/marwil_policy_graph.py +++ b/python/ray/rllib/agents/marwil/marwil_policy.py @@ -2,19 +2,20 @@ from __future__ import division from __future__ import print_function -import tensorflow as tf - import ray from ray.rllib.models import ModelCatalog from ray.rllib.evaluation.postprocessing import compute_advantages, \ Postprocessing -from ray.rllib.evaluation.sample_batch import SampleBatch +from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY from ray.rllib.utils.annotations import override -from ray.rllib.evaluation.policy_graph import PolicyGraph -from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph -from ray.rllib.agents.dqn.dqn_policy_graph import _scope_vars +from ray.rllib.policy.policy import Policy +from ray.rllib.policy.tf_policy import TFPolicy +from ray.rllib.agents.dqn.dqn_policy import _scope_vars from ray.rllib.utils.explained_variance import explained_variance +from ray.rllib.utils import try_import_tf + +tf = try_import_tf() POLICY_SCOPE = "p_func" VALUE_SCOPE = "v_func" @@ -58,7 +59,7 @@ def __init__(self, state_values, cumulative_rewards, logits, actions, class MARWILPostprocessing(object): """Adds the advantages field to the trajectory.""" - @override(PolicyGraph) + @override(Policy) def postprocess_trajectory(self, sample_batch, other_agent_batches=None, @@ -78,7 +79,7 @@ def postprocess_trajectory(self, return batch -class MARWILPolicyGraph(MARWILPostprocessing, TFPolicyGraph): +class MARWILPolicy(MARWILPostprocessing, TFPolicy): def __init__(self, observation_space, action_space, config): config = dict(ray.rllib.agents.dqn.dqn.DEFAULT_CONFIG, **config) self.config = config @@ -126,14 +127,14 @@ def __init__(self, observation_space, action_space, config): self.explained_variance = tf.reduce_mean( explained_variance(self.cum_rew_t, state_values)) - # initialize TFPolicyGraph + # initialize TFPolicy self.sess = tf.get_default_session() self.loss_inputs = [ (SampleBatch.CUR_OBS, self.obs_t), (SampleBatch.ACTIONS, self.act_t), (Postprocessing.ADVANTAGES, self.cum_rew_t), ] - TFPolicyGraph.__init__( + TFPolicy.__init__( self, observation_space, action_space, @@ -165,10 +166,10 @@ def _build_policy_loss(self, state_values, cum_rwds, logits, actions, return ReweightedImitationLoss(state_values, cum_rwds, logits, actions, action_space, self.config["beta"]) - @override(TFPolicyGraph) + @override(TFPolicy) def extra_compute_grad_fetches(self): return {LEARNER_STATS_KEY: self.stats_fetches} - @override(PolicyGraph) + @override(Policy) def get_initial_state(self): return self.model.state_init diff --git a/python/ray/rllib/agents/pg/__init__.py b/python/ray/rllib/agents/pg/__init__.py index 2203188a7ca6..eb11c99bf625 100644 --- a/python/ray/rllib/agents/pg/__init__.py +++ b/python/ray/rllib/agents/pg/__init__.py @@ -1,6 +1,6 @@ from ray.rllib.agents.pg.pg import PGTrainer, DEFAULT_CONFIG -from ray.rllib.utils import renamed_class +from ray.rllib.utils import renamed_agent -PGAgent = renamed_class(PGTrainer) +PGAgent = renamed_agent(PGTrainer) __all__ = ["PGAgent", "PGTrainer", "DEFAULT_CONFIG"] diff --git a/python/ray/rllib/agents/pg/pg.py b/python/ray/rllib/agents/pg/pg.py index e70fdcc8b2c6..299cdcac3de4 100644 --- a/python/ray/rllib/agents/pg/pg.py +++ b/python/ray/rllib/agents/pg/pg.py @@ -2,11 +2,9 @@ from __future__ import 
division from __future__ import print_function -from ray.rllib.agents.trainer import Trainer, with_common_config -from ray.rllib.agents.pg.pg_policy_graph import PGPolicyGraph - -from ray.rllib.optimizers import SyncSamplesOptimizer -from ray.rllib.utils.annotations import override +from ray.rllib.agents.trainer import with_common_config +from ray.rllib.agents.trainer_template import build_trainer +from ray.rllib.agents.pg.pg_policy import PGTFPolicy # yapf: disable # __sphinx_doc_begin__ @@ -22,40 +20,16 @@ # yapf: enable -class PGTrainer(Trainer): - """Simple policy gradient agent. - - This is an example agent to show how to implement algorithms in RLlib. - In most cases, you will probably want to use the PPO agent instead. - """ - - _name = "PG" - _default_config = DEFAULT_CONFIG - _policy_graph = PGPolicyGraph +def get_policy_class(config): + if config["use_pytorch"]: + from ray.rllib.agents.pg.torch_pg_policy import PGTorchPolicy + return PGTorchPolicy + else: + return PGTFPolicy - @override(Trainer) - def _init(self, config, env_creator): - if config["use_pytorch"]: - from ray.rllib.agents.pg.torch_pg_policy_graph import \ - PGTorchPolicyGraph - policy_cls = PGTorchPolicyGraph - else: - policy_cls = self._policy_graph - self.local_evaluator = self.make_local_evaluator( - env_creator, policy_cls) - self.remote_evaluators = self.make_remote_evaluators( - env_creator, policy_cls, config["num_workers"]) - optimizer_config = dict( - config["optimizer"], - **{"train_batch_size": config["train_batch_size"]}) - self.optimizer = SyncSamplesOptimizer( - self.local_evaluator, self.remote_evaluators, **optimizer_config) - @override(Trainer) - def _train(self): - prev_steps = self.optimizer.num_steps_sampled - self.optimizer.step() - result = self.collect_metrics() - result.update(timesteps_this_iter=self.optimizer.num_steps_sampled - - prev_steps) - return result +PGTrainer = build_trainer( + name="PGTrainer", + default_config=DEFAULT_CONFIG, + default_policy=PGTFPolicy, + get_policy_class=get_policy_class) diff --git a/python/ray/rllib/agents/pg/pg_policy.py b/python/ray/rllib/agents/pg/pg_policy.py new file mode 100644 index 000000000000..7cca613928fb --- /dev/null +++ b/python/ray/rllib/agents/pg/pg_policy.py @@ -0,0 +1,35 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import ray +from ray.rllib.evaluation.postprocessing import compute_advantages, \ + Postprocessing +from ray.rllib.policy.tf_policy_template import build_tf_policy +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.utils import try_import_tf + +tf = try_import_tf() + + +# The basic policy gradients loss +def policy_gradient_loss(policy, batch_tensors): + actions = batch_tensors[SampleBatch.ACTIONS] + advantages = batch_tensors[Postprocessing.ADVANTAGES] + return -tf.reduce_mean(policy.action_dist.logp(actions) * advantages) + + +# This adds the "advantages" column to the sample batch. 
+def postprocess_advantages(policy, + sample_batch, + other_agent_batches=None, + episode=None): + return compute_advantages( + sample_batch, 0.0, policy.config["gamma"], use_gae=False) + + +PGTFPolicy = build_tf_policy( + name="PGTFPolicy", + get_default_config=lambda: ray.rllib.agents.pg.pg.DEFAULT_CONFIG, + postprocess_fn=postprocess_advantages, + loss_fn=policy_gradient_loss) diff --git a/python/ray/rllib/agents/pg/pg_policy_graph.py b/python/ray/rllib/agents/pg/pg_policy_graph.py deleted file mode 100644 index 6e8abd7d4a81..000000000000 --- a/python/ray/rllib/agents/pg/pg_policy_graph.py +++ /dev/null @@ -1,103 +0,0 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import tensorflow as tf - -import ray -from ray.rllib.models.catalog import ModelCatalog -from ray.rllib.evaluation.postprocessing import compute_advantages, \ - Postprocessing -from ray.rllib.evaluation.policy_graph import PolicyGraph -from ray.rllib.evaluation.sample_batch import SampleBatch -from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph -from ray.rllib.utils.annotations import override - - -class PGLoss(object): - """The basic policy gradient loss.""" - - def __init__(self, action_dist, actions, advantages): - self.loss = -tf.reduce_mean(action_dist.logp(actions) * advantages) - - -class PGPostprocessing(object): - """Adds the advantages field to the trajectory.""" - - @override(PolicyGraph) - def postprocess_trajectory(self, - sample_batch, - other_agent_batches=None, - episode=None): - # This adds the "advantages" column to the sample batch - return compute_advantages( - sample_batch, 0.0, self.config["gamma"], use_gae=False) - - -class PGPolicyGraph(PGPostprocessing, TFPolicyGraph): - """Simple policy gradient example of defining a policy graph.""" - - def __init__(self, obs_space, action_space, config): - config = dict(ray.rllib.agents.pg.pg.DEFAULT_CONFIG, **config) - self.config = config - - # Setup placeholders - obs = tf.placeholder(tf.float32, shape=[None] + list(obs_space.shape)) - dist_class, self.logit_dim = ModelCatalog.get_action_dist( - action_space, self.config["model"]) - prev_actions = ModelCatalog.get_action_placeholder(action_space) - prev_rewards = tf.placeholder(tf.float32, [None], name="prev_reward") - - # Create the model network and action outputs - self.model = ModelCatalog.get_model({ - "obs": obs, - "prev_actions": prev_actions, - "prev_rewards": prev_rewards, - "is_training": self._get_is_training_placeholder(), - }, obs_space, action_space, self.logit_dim, self.config["model"]) - action_dist = dist_class(self.model.outputs) # logit for each action - - # Setup policy loss - actions = ModelCatalog.get_action_placeholder(action_space) - advantages = tf.placeholder(tf.float32, [None], name="adv") - loss = PGLoss(action_dist, actions, advantages).loss - - # Mapping from sample batch keys to placeholders. These keys will be - # read from postprocessed sample batches and fed into the specified - # placeholders during loss computation. 
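# Editor's usage sketch (not part of the patch): how another policy could be
# assembled with the same builder call that defines PGTFPolicy above. The
# names MyTFPolicy, my_pg_loss, and my_postprocess are invented; the builder
# keywords and the (policy, batch_tensors) loss signature are taken only from
# that call site.
import ray
from ray.rllib.evaluation.postprocessing import compute_advantages, \
    Postprocessing
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.tf_policy_template import build_tf_policy
from ray.rllib.utils import try_import_tf

tf = try_import_tf()

def my_pg_loss(policy, batch_tensors):
    # the same REINFORCE objective as policy_gradient_loss above
    logp = policy.action_dist.logp(batch_tensors[SampleBatch.ACTIONS])
    return -tf.reduce_mean(logp * batch_tensors[Postprocessing.ADVANTAGES])

def my_postprocess(policy, sample_batch, other_agent_batches=None,
                   episode=None):
    return compute_advantages(
        sample_batch, 0.0, policy.config["gamma"], use_gae=False)

MyTFPolicy = build_tf_policy(
    name="MyTFPolicy",
    get_default_config=lambda: ray.rllib.agents.pg.pg.DEFAULT_CONFIG,
    postprocess_fn=my_postprocess,
    loss_fn=my_pg_loss)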
- loss_in = [ - (SampleBatch.CUR_OBS, obs), - (SampleBatch.ACTIONS, actions), - (SampleBatch.PREV_ACTIONS, prev_actions), - (SampleBatch.PREV_REWARDS, prev_rewards), - (Postprocessing.ADVANTAGES, advantages), - ] - - # Initialize TFPolicyGraph - sess = tf.get_default_session() - TFPolicyGraph.__init__( - self, - obs_space, - action_space, - sess, - obs_input=obs, - action_sampler=action_dist.sample(), - action_prob=action_dist.sampled_action_prob(), - loss=loss, - loss_inputs=loss_in, - model=self.model, - state_inputs=self.model.state_in, - state_outputs=self.model.state_out, - prev_action_input=prev_actions, - prev_reward_input=prev_rewards, - seq_lens=self.model.seq_lens, - max_seq_len=config["model"]["max_seq_len"]) - sess.run(tf.global_variables_initializer()) - - @override(PolicyGraph) - def get_initial_state(self): - return self.model.state_init - - @override(TFPolicyGraph) - def optimizer(self): - return tf.train.AdamOptimizer(learning_rate=self.config["lr"]) diff --git a/python/ray/rllib/agents/pg/torch_pg_policy.py b/python/ray/rllib/agents/pg/torch_pg_policy.py new file mode 100644 index 000000000000..d0f1cda71cc7 --- /dev/null +++ b/python/ray/rllib/agents/pg/torch_pg_policy.py @@ -0,0 +1,42 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import ray +from ray.rllib.evaluation.postprocessing import compute_advantages, \ + Postprocessing +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.policy.torch_policy_template import build_torch_policy + + +def pg_torch_loss(policy, batch_tensors): + logits, _, values, _ = policy.model({ + SampleBatch.CUR_OBS: batch_tensors[SampleBatch.CUR_OBS] + }, []) + action_dist = policy.dist_class(logits) + log_probs = action_dist.logp(batch_tensors[SampleBatch.ACTIONS]) + # save the error in the policy object + policy.pi_err = -batch_tensors[Postprocessing.ADVANTAGES].dot( + log_probs.reshape(-1)) + return policy.pi_err + + +def postprocess_advantages(policy, + sample_batch, + other_agent_batches=None, + episode=None): + return compute_advantages( + sample_batch, 0.0, policy.config["gamma"], use_gae=False) + + +def pg_loss_stats(policy, batch_tensors): + # the error is recorded when computing the loss + return {"policy_loss": policy.pi_err.item()} + + +PGTorchPolicy = build_torch_policy( + name="PGTorchPolicy", + get_default_config=lambda: ray.rllib.agents.a3c.a3c.DEFAULT_CONFIG, + loss_fn=pg_torch_loss, + stats_fn=pg_loss_stats, + postprocess_fn=postprocess_advantages) diff --git a/python/ray/rllib/agents/pg/torch_pg_policy_graph.py b/python/ray/rllib/agents/pg/torch_pg_policy_graph.py deleted file mode 100644 index 746ef1bca42f..000000000000 --- a/python/ray/rllib/agents/pg/torch_pg_policy_graph.py +++ /dev/null @@ -1,83 +0,0 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import torch -from torch import nn - -import ray -from ray.rllib.models.catalog import ModelCatalog -from ray.rllib.evaluation.postprocessing import compute_advantages, \ - Postprocessing -from ray.rllib.evaluation.policy_graph import PolicyGraph -from ray.rllib.evaluation.sample_batch import SampleBatch -from ray.rllib.evaluation.torch_policy_graph import TorchPolicyGraph -from ray.rllib.utils.annotations import override - - -class PGLoss(nn.Module): - def __init__(self, dist_class): - nn.Module.__init__(self) - self.dist_class = dist_class - - def forward(self, policy_model, observations, actions, advantages): - logits, 
_, values, _ = policy_model({ - SampleBatch.CUR_OBS: observations - }, []) - dist = self.dist_class(logits) - log_probs = dist.logp(actions) - self.pi_err = -advantages.dot(log_probs.reshape(-1)) - return self.pi_err - - -class PGPostprocessing(object): - """Adds the value func output and advantages field to the trajectory.""" - - @override(TorchPolicyGraph) - def extra_action_out(self, model_out): - return {SampleBatch.VF_PREDS: model_out[2].cpu().numpy()} - - @override(PolicyGraph) - def postprocess_trajectory(self, - sample_batch, - other_agent_batches=None, - episode=None): - return compute_advantages( - sample_batch, 0.0, self.config["gamma"], use_gae=False) - - -class PGTorchPolicyGraph(PGPostprocessing, TorchPolicyGraph): - def __init__(self, obs_space, action_space, config): - config = dict(ray.rllib.agents.a3c.a3c.DEFAULT_CONFIG, **config) - self.config = config - dist_class, self.logit_dim = ModelCatalog.get_action_dist( - action_space, self.config["model"], torch=True) - model = ModelCatalog.get_torch_model(obs_space, self.logit_dim, - self.config["model"]) - loss = PGLoss(dist_class) - - TorchPolicyGraph.__init__( - self, - obs_space, - action_space, - model, - loss, - loss_inputs=[ - SampleBatch.CUR_OBS, SampleBatch.ACTIONS, - Postprocessing.ADVANTAGES - ], - action_distribution_cls=dist_class) - - @override(TorchPolicyGraph) - def optimizer(self): - return torch.optim.Adam(self._model.parameters(), lr=self.config["lr"]) - - @override(TorchPolicyGraph) - def extra_grad_info(self): - return {"policy_loss": self._loss.pi_err.item()} - - def _value(self, obs): - with self.lock: - obs = torch.from_numpy(obs).float().unsqueeze(0).to(self.device) - _, _, vf, _ = self.model({"obs": obs}, []) - return vf.detach().cpu().numpy().squeeze() diff --git a/python/ray/rllib/agents/ppo/__init__.py b/python/ray/rllib/agents/ppo/__init__.py index a02cbc23c684..a3d492baf24a 100644 --- a/python/ray/rllib/agents/ppo/__init__.py +++ b/python/ray/rllib/agents/ppo/__init__.py @@ -1,7 +1,7 @@ from ray.rllib.agents.ppo.ppo import PPOTrainer, DEFAULT_CONFIG from ray.rllib.agents.ppo.appo import APPOTrainer -from ray.rllib.utils import renamed_class +from ray.rllib.utils import renamed_agent -PPOAgent = renamed_class(PPOTrainer) +PPOAgent = renamed_agent(PPOTrainer) __all__ = ["PPOAgent", "APPOTrainer", "PPOTrainer", "DEFAULT_CONFIG"] diff --git a/python/ray/rllib/agents/ppo/appo.py b/python/ray/rllib/agents/ppo/appo.py index ac3251775d52..0438b2714221 100644 --- a/python/ray/rllib/agents/ppo/appo.py +++ b/python/ray/rllib/agents/ppo/appo.py @@ -2,7 +2,7 @@ from __future__ import division from __future__ import print_function -from ray.rllib.agents.ppo.appo_policy_graph import AsyncPPOPolicyGraph +from ray.rllib.agents.ppo.appo_policy import AsyncPPOTFPolicy from ray.rllib.agents.trainer import with_base_config from ray.rllib.agents import impala from ray.rllib.utils.annotations import override @@ -57,8 +57,8 @@ class APPOTrainer(impala.ImpalaTrainer): _name = "APPO" _default_config = DEFAULT_CONFIG - _policy_graph = AsyncPPOPolicyGraph + _policy = AsyncPPOTFPolicy @override(impala.ImpalaTrainer) - def _get_policy_graph(self): - return AsyncPPOPolicyGraph + def _get_policy(self): + return AsyncPPOTFPolicy diff --git a/python/ray/rllib/agents/ppo/appo_policy.py b/python/ray/rllib/agents/ppo/appo_policy.py new file mode 100644 index 000000000000..b740d6d81430 --- /dev/null +++ b/python/ray/rllib/agents/ppo/appo_policy.py @@ -0,0 +1,396 @@ +"""Adapted from VTraceTFPolicy to use the PPO surrogate loss. 
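# Editor's sketch (not part of the patch): the dot-product form of the policy
# gradient loss used by pg_torch_loss above, shown standalone in PyTorch.
# Tensors and values are made up; only the formula
# -(advantages . log pi(a|s)) is taken from the code.
import torch

log_probs = torch.log(torch.tensor([0.7, 0.2, 0.9]))   # log pi(a_t|s_t)
advantages = torch.tensor([1.0, -0.5, 2.0])            # from compute_advantages
pi_err = -advantages.dot(log_probs.reshape(-1))        # scalar policy loss
print(pi_err.item())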
+ +Keep in sync with changes to VTraceTFPolicy.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import logging +import gym + +import ray +from ray.rllib.agents.impala import vtrace +from ray.rllib.evaluation.postprocessing import Postprocessing +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.policy.tf_policy_template import build_tf_policy +from ray.rllib.policy.tf_policy import LearningRateSchedule +from ray.rllib.utils.explained_variance import explained_variance +from ray.rllib.evaluation.postprocessing import compute_advantages +from ray.rllib.utils import try_import_tf + +tf = try_import_tf() + +logger = logging.getLogger(__name__) + +BEHAVIOUR_LOGITS = "behaviour_logits" + + +class PPOSurrogateLoss(object): + """Loss used when V-trace is disabled. + + Arguments: + prev_actions_logp: A float32 tensor of shape [T, B]. + actions_logp: A float32 tensor of shape [T, B]. + action_kl: A float32 tensor of shape [T, B]. + actions_entropy: A float32 tensor of shape [T, B]. + values: A float32 tensor of shape [T, B]. + valid_mask: A bool tensor of valid RNN input elements (#2992). + advantages: A float32 tensor of shape [T, B]. + value_targets: A float32 tensor of shape [T, B]. + """ + + def __init__(self, + prev_actions_logp, + actions_logp, + action_kl, + actions_entropy, + values, + valid_mask, + advantages, + value_targets, + vf_loss_coeff=0.5, + entropy_coeff=0.01, + clip_param=0.3): + + logp_ratio = tf.exp(actions_logp - prev_actions_logp) + + surrogate_loss = tf.minimum( + advantages * logp_ratio, + advantages * tf.clip_by_value(logp_ratio, 1 - clip_param, + 1 + clip_param)) + + self.mean_kl = tf.reduce_mean(action_kl) + self.pi_loss = -tf.reduce_sum(surrogate_loss) + + # The baseline loss + delta = tf.boolean_mask(values - value_targets, valid_mask) + self.value_targets = value_targets + self.vf_loss = 0.5 * tf.reduce_sum(tf.square(delta)) + + # The entropy loss + self.entropy = tf.reduce_sum( + tf.boolean_mask(actions_entropy, valid_mask)) + + # The summed weighted loss + self.total_loss = (self.pi_loss + self.vf_loss * vf_loss_coeff - + self.entropy * entropy_coeff) + + +class VTraceSurrogateLoss(object): + def __init__(self, + actions, + prev_actions_logp, + actions_logp, + action_kl, + actions_entropy, + dones, + behaviour_logits, + target_logits, + discount, + rewards, + values, + bootstrap_value, + dist_class, + valid_mask, + vf_loss_coeff=0.5, + entropy_coeff=0.01, + clip_rho_threshold=1.0, + clip_pg_rho_threshold=1.0, + clip_param=0.3): + """PPO surrogate loss with vtrace importance weighting. + + VTraceLoss takes tensors of shape [T, B, ...], where `B` is the + batch_size. The reason we need to know `B` is for V-trace to properly + handle episode cut boundaries. + + Arguments: + actions: An int|float32 tensor of shape [T, B, logit_dim]. + prev_actions_logp: A float32 tensor of shape [T, B]. + actions_logp: A float32 tensor of shape [T, B]. + action_kl: A float32 tensor of shape [T, B]. + actions_entropy: A float32 tensor of shape [T, B]. + dones: A bool tensor of shape [T, B]. + behaviour_logits: A float32 tensor of shape [T, B, logit_dim]. + target_logits: A float32 tensor of shape [T, B, logit_dim]. + discount: A float32 scalar. + rewards: A float32 tensor of shape [T, B]. + values: A float32 tensor of shape [T, B]. + bootstrap_value: A float32 tensor of shape [B]. + dist_class: action distribution class for logits. 
+ valid_mask: A bool tensor of valid RNN input elements (#2992). + """ + + # Compute vtrace on the CPU for better perf. + with tf.device("/cpu:0"): + self.vtrace_returns = vtrace.multi_from_logits( + behaviour_policy_logits=behaviour_logits, + target_policy_logits=target_logits, + actions=tf.unstack(actions, axis=2), + discounts=tf.to_float(~dones) * discount, + rewards=rewards, + values=values, + bootstrap_value=bootstrap_value, + dist_class=dist_class, + clip_rho_threshold=tf.cast(clip_rho_threshold, tf.float32), + clip_pg_rho_threshold=tf.cast(clip_pg_rho_threshold, + tf.float32)) + + logp_ratio = tf.exp(actions_logp - prev_actions_logp) + + advantages = self.vtrace_returns.pg_advantages + surrogate_loss = tf.minimum( + advantages * logp_ratio, + advantages * tf.clip_by_value(logp_ratio, 1 - clip_param, + 1 + clip_param)) + + self.mean_kl = tf.reduce_mean(action_kl) + self.pi_loss = -tf.reduce_sum(surrogate_loss) + + # The baseline loss + delta = tf.boolean_mask(values - self.vtrace_returns.vs, valid_mask) + self.value_targets = self.vtrace_returns.vs + self.vf_loss = 0.5 * tf.reduce_sum(tf.square(delta)) + + # The entropy loss + self.entropy = tf.reduce_sum( + tf.boolean_mask(actions_entropy, valid_mask)) + + # The summed weighted loss + self.total_loss = (self.pi_loss + self.vf_loss * vf_loss_coeff - + self.entropy * entropy_coeff) + + +def _make_time_major(policy, tensor, drop_last=False): + """Swaps batch and trajectory axis. + + Arguments: + policy: Policy reference + tensor: A tensor or list of tensors to reshape. + drop_last: A bool indicating whether to drop the last + trajectory item. + + Returns: + res: A tensor with swapped axes or a list of tensors with + swapped axes. + """ + if isinstance(tensor, list): + return [_make_time_major(policy, t, drop_last) for t in tensor] + + if policy.model.state_init: + B = tf.shape(policy.model.seq_lens)[0] + T = tf.shape(tensor)[0] // B + else: + # Important: chop the tensor into batches at known episode cut + # boundaries. 
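# Editor's sketch (not part of the patch): the clipped surrogate term that
# both PPOSurrogateLoss and VTraceSurrogateLoss above build with tf.minimum
# and tf.clip_by_value, written out in NumPy. The numbers are invented; the
# formula min(A * r, A * clip(r, 1 - eps, 1 + eps)) is the one in the code.
import numpy as np

def clipped_surrogate(advantages, logp_ratio, clip_param=0.3):
    return np.minimum(
        advantages * logp_ratio,
        advantages * np.clip(logp_ratio, 1 - clip_param, 1 + clip_param))

adv = np.array([1.0, 1.0, -1.0])
ratio = np.array([0.5, 2.0, 2.0])      # exp(new_logp - old_logp)
print(clipped_surrogate(adv, ratio))   # -> [ 0.5  1.3 -2. ]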
TODO(ekl) this is kind of a hack + T = policy.config["sample_batch_size"] + B = tf.shape(tensor)[0] // T + rs = tf.reshape(tensor, tf.concat([[B, T], tf.shape(tensor)[1:]], axis=0)) + + # swap B and T axes + res = tf.transpose( + rs, [1, 0] + list(range(2, 1 + int(tf.shape(tensor).shape[0])))) + + if drop_last: + return res[:-1] + return res + + +def build_appo_surrogate_loss(policy, batch_tensors): + if isinstance(policy.action_space, gym.spaces.Discrete): + is_multidiscrete = False + output_hidden_shape = [policy.action_space.n] + elif isinstance(policy.action_space, + gym.spaces.multi_discrete.MultiDiscrete): + is_multidiscrete = True + output_hidden_shape = policy.action_space.nvec.astype(np.int32) + else: + is_multidiscrete = False + output_hidden_shape = 1 + + def make_time_major(*args, **kw): + return _make_time_major(policy, *args, **kw) + + actions = batch_tensors[SampleBatch.ACTIONS] + dones = batch_tensors[SampleBatch.DONES] + rewards = batch_tensors[SampleBatch.REWARDS] + behaviour_logits = batch_tensors[BEHAVIOUR_LOGITS] + unpacked_behaviour_logits = tf.split( + behaviour_logits, output_hidden_shape, axis=1) + unpacked_outputs = tf.split( + policy.model.outputs, output_hidden_shape, axis=1) + prev_dist_inputs = unpacked_behaviour_logits if is_multidiscrete else \ + behaviour_logits + action_dist = policy.action_dist + prev_action_dist = policy.dist_class(prev_dist_inputs) + values = policy.value_function + + if policy.model.state_in: + max_seq_len = tf.reduce_max(policy.model.seq_lens) - 1 + mask = tf.sequence_mask(policy.model.seq_lens, max_seq_len) + mask = tf.reshape(mask, [-1]) + else: + mask = tf.ones_like(rewards) + + if policy.config["vtrace"]: + logger.info("Using V-Trace surrogate loss (vtrace=True)") + + # Prepare actions for loss + loss_actions = actions if is_multidiscrete else tf.expand_dims( + actions, axis=1) + + policy.loss = VTraceSurrogateLoss( + actions=make_time_major(loss_actions, drop_last=True), + prev_actions_logp=make_time_major( + prev_action_dist.logp(actions), drop_last=True), + actions_logp=make_time_major( + action_dist.logp(actions), drop_last=True), + action_kl=prev_action_dist.kl(action_dist), + actions_entropy=make_time_major( + action_dist.entropy(), drop_last=True), + dones=make_time_major(dones, drop_last=True), + behaviour_logits=make_time_major( + unpacked_behaviour_logits, drop_last=True), + target_logits=make_time_major(unpacked_outputs, drop_last=True), + discount=policy.config["gamma"], + rewards=make_time_major(rewards, drop_last=True), + values=make_time_major(values, drop_last=True), + bootstrap_value=make_time_major(values)[-1], + dist_class=policy.dist_class, + valid_mask=make_time_major(mask, drop_last=True), + vf_loss_coeff=policy.config["vf_loss_coeff"], + entropy_coeff=policy.config["entropy_coeff"], + clip_rho_threshold=policy.config["vtrace_clip_rho_threshold"], + clip_pg_rho_threshold=policy.config[ + "vtrace_clip_pg_rho_threshold"], + clip_param=policy.config["clip_param"]) + else: + logger.info("Using PPO surrogate loss (vtrace=False)") + policy.loss = PPOSurrogateLoss( + prev_actions_logp=make_time_major(prev_action_dist.logp(actions)), + actions_logp=make_time_major(action_dist.logp(actions)), + action_kl=prev_action_dist.kl(action_dist), + actions_entropy=make_time_major(action_dist.entropy()), + values=make_time_major(values), + valid_mask=make_time_major(mask), + advantages=make_time_major( + batch_tensors[Postprocessing.ADVANTAGES]), + value_targets=make_time_major( + batch_tensors[Postprocessing.VALUE_TARGETS]), + 
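# Editor's sketch (not part of the patch): what _make_time_major above does in
# the fixed-length (no RNN state) branch. A flat [B * T, ...] batch is cut at
# known episode boundaries into [B, T, ...] and then transposed to time-major
# [T, B, ...]. Shapes are invented for the example.
import numpy as np

B, T, D = 2, 4, 3                      # rollouts, sample_batch_size, features
flat = np.arange(B * T * D).reshape(B * T, D)
time_major = flat.reshape(B, T, D).transpose(1, 0, 2)
assert time_major.shape == (T, B, D)
assert (time_major[:, 0] == flat[:T]).all()   # first rollout becomes column 0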
vf_loss_coeff=policy.config["vf_loss_coeff"], + entropy_coeff=policy.config["entropy_coeff"], + clip_param=policy.config["clip_param"]) + + return policy.loss.total_loss + + +def stats(policy, batch_tensors): + values_batched = _make_time_major( + policy, policy.value_function, drop_last=policy.config["vtrace"]) + + return { + "cur_lr": tf.cast(policy.cur_lr, tf.float64), + "policy_loss": policy.loss.pi_loss, + "entropy": policy.loss.entropy, + "var_gnorm": tf.global_norm(policy.var_list), + "vf_loss": policy.loss.vf_loss, + "vf_explained_var": explained_variance( + tf.reshape(policy.loss.value_targets, [-1]), + tf.reshape(values_batched, [-1])), + } + + +def grad_stats(policy, grads): + return { + "grad_gnorm": tf.global_norm(grads), + } + + +def postprocess_trajectory(policy, + sample_batch, + other_agent_batches=None, + episode=None): + if not policy.config["vtrace"]: + completed = sample_batch["dones"][-1] + if completed: + last_r = 0.0 + else: + next_state = [] + for i in range(len(policy.model.state_in)): + next_state.append([sample_batch["state_out_{}".format(i)][-1]]) + last_r = policy.value(sample_batch["new_obs"][-1], *next_state) + batch = compute_advantages( + sample_batch, + last_r, + policy.config["gamma"], + policy.config["lambda"], + use_gae=policy.config["use_gae"]) + else: + batch = sample_batch + del batch.data["new_obs"] # not used, so save some bandwidth + return batch + + +def add_values_and_logits(policy): + out = {BEHAVIOUR_LOGITS: policy.model.outputs} + if not policy.config["vtrace"]: + out[SampleBatch.VF_PREDS] = policy.value_function + return out + + +def validate_config(policy, obs_space, action_space, config): + assert config["batch_mode"] == "truncate_episodes", \ + "Must use `truncate_episodes` batch mode with V-trace." 
+ + +def choose_optimizer(policy, config): + if policy.config["opt_type"] == "adam": + return tf.train.AdamOptimizer(policy.cur_lr) + else: + return tf.train.RMSPropOptimizer(policy.cur_lr, config["decay"], + config["momentum"], config["epsilon"]) + + +def clip_gradients(policy, optimizer, loss): + grads = tf.gradients(loss, policy.var_list) + policy.grads, _ = tf.clip_by_global_norm(grads, policy.config["grad_clip"]) + clipped_grads = list(zip(policy.grads, policy.var_list)) + return clipped_grads + + +class ValueNetworkMixin(object): + def __init__(self): + self.value_function = self.model.value_function() + self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, + tf.get_variable_scope().name) + + def value(self, ob, *args): + feed_dict = {self._obs_input: [ob], self.model.seq_lens: [1]} + assert len(args) == len(self.model.state_in), \ + (args, self.model.state_in) + for k, v in zip(self.model.state_in, args): + feed_dict[k] = v + vf = self._sess.run(self.value_function, feed_dict) + return vf[0] + + +def setup_mixins(policy, obs_space, action_space, config): + LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"]) + ValueNetworkMixin.__init__(policy) + + +AsyncPPOTFPolicy = build_tf_policy( + name="AsyncPPOTFPolicy", + get_default_config=lambda: ray.rllib.agents.impala.impala.DEFAULT_CONFIG, + loss_fn=build_appo_surrogate_loss, + stats_fn=stats, + grad_stats_fn=grad_stats, + postprocess_fn=postprocess_trajectory, + optimizer_fn=choose_optimizer, + gradients_fn=clip_gradients, + extra_action_fetches_fn=add_values_and_logits, + before_init=validate_config, + before_loss_init=setup_mixins, + mixins=[LearningRateSchedule, ValueNetworkMixin], + get_batch_divisibility_req=lambda p: p.config["sample_batch_size"]) diff --git a/python/ray/rllib/agents/ppo/appo_policy_graph.py b/python/ray/rllib/agents/ppo/appo_policy_graph.py deleted file mode 100644 index 89e49153f90c..000000000000 --- a/python/ray/rllib/agents/ppo/appo_policy_graph.py +++ /dev/null @@ -1,497 +0,0 @@ -"""Adapted from VTracePolicyGraph to use the PPO surrogate loss. - -Keep in sync with changes to VTracePolicyGraph.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np -import tensorflow as tf -import logging -import gym - -import ray -from ray.rllib.agents.impala import vtrace -from ray.rllib.evaluation.policy_graph import PolicyGraph -from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY -from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph, \ - LearningRateSchedule -from ray.rllib.models.catalog import ModelCatalog -from ray.rllib.utils.annotations import override -from ray.rllib.utils.error import UnsupportedSpaceException -from ray.rllib.utils.explained_variance import explained_variance -from ray.rllib.models.action_dist import MultiCategorical -from ray.rllib.evaluation.postprocessing import compute_advantages - -logger = logging.getLogger(__name__) - - -class PPOSurrogateLoss(object): - """Loss used when V-trace is disabled. - - Arguments: - prev_actions_logp: A float32 tensor of shape [T, B]. - actions_logp: A float32 tensor of shape [T, B]. - action_kl: A float32 tensor of shape [T, B]. - actions_entropy: A float32 tensor of shape [T, B]. - values: A float32 tensor of shape [T, B]. - valid_mask: A bool tensor of valid RNN input elements (#2992). - advantages: A float32 tensor of shape [T, B]. - value_targets: A float32 tensor of shape [T, B]. 
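# Editor's sketch (not part of the patch): the mixin wiring used for
# AsyncPPOTFPolicy above. Plain classes are handed to build_tf_policy via
# `mixins=[...]`, and before_loss_init (setup_mixins) calls each mixin's
# __init__ on the generated policy by hand. The stand-in class below is
# invented purely to illustrate that composition.
class CounterMixin(object):
    def __init__(self):
        self.num_updates = 0

    def bump(self):
        self.num_updates += 1
        return self.num_updates

def setup_counter_mixin(policy, obs_space, action_space, config):
    # analogous to setup_mixins above: run the mixin constructor explicitly,
    # since the generated policy class does not call it for us
    CounterMixin.__init__(policy)

# Passing mixins=[CounterMixin], before_loss_init=setup_counter_mixin to
# build_tf_policy would then expose policy.bump() next to the generated API.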
- """ - - def __init__(self, - prev_actions_logp, - actions_logp, - action_kl, - actions_entropy, - values, - valid_mask, - advantages, - value_targets, - vf_loss_coeff=0.5, - entropy_coeff=0.01, - clip_param=0.3): - - logp_ratio = tf.exp(actions_logp - prev_actions_logp) - - surrogate_loss = tf.minimum( - advantages * logp_ratio, - advantages * tf.clip_by_value(logp_ratio, 1 - clip_param, - 1 + clip_param)) - - self.mean_kl = tf.reduce_mean(action_kl) - self.pi_loss = -tf.reduce_sum(surrogate_loss) - - # The baseline loss - delta = tf.boolean_mask(values - value_targets, valid_mask) - self.value_targets = value_targets - self.vf_loss = 0.5 * tf.reduce_sum(tf.square(delta)) - - # The entropy loss - self.entropy = tf.reduce_sum( - tf.boolean_mask(actions_entropy, valid_mask)) - - # The summed weighted loss - self.total_loss = (self.pi_loss + self.vf_loss * vf_loss_coeff - - self.entropy * entropy_coeff) - - -class VTraceSurrogateLoss(object): - def __init__(self, - actions, - prev_actions_logp, - actions_logp, - action_kl, - actions_entropy, - dones, - behaviour_logits, - target_logits, - discount, - rewards, - values, - bootstrap_value, - valid_mask, - vf_loss_coeff=0.5, - entropy_coeff=0.01, - clip_rho_threshold=1.0, - clip_pg_rho_threshold=1.0, - clip_param=0.3): - """PPO surrogate loss with vtrace importance weighting. - - VTraceLoss takes tensors of shape [T, B, ...], where `B` is the - batch_size. The reason we need to know `B` is for V-trace to properly - handle episode cut boundaries. - - Arguments: - actions: An int32 tensor of shape [T, B, NUM_ACTIONS]. - prev_actions_logp: A float32 tensor of shape [T, B]. - actions_logp: A float32 tensor of shape [T, B]. - action_kl: A float32 tensor of shape [T, B]. - actions_entropy: A float32 tensor of shape [T, B]. - dones: A bool tensor of shape [T, B]. - behaviour_logits: A float32 tensor of shape [T, B, NUM_ACTIONS]. - target_logits: A float32 tensor of shape [T, B, NUM_ACTIONS]. - discount: A float32 scalar. - rewards: A float32 tensor of shape [T, B]. - values: A float32 tensor of shape [T, B]. - bootstrap_value: A float32 tensor of shape [B]. - valid_mask: A bool tensor of valid RNN input elements (#2992). - """ - - # Compute vtrace on the CPU for better perf. 
- with tf.device("/cpu:0"): - self.vtrace_returns = vtrace.multi_from_logits( - behaviour_policy_logits=behaviour_logits, - target_policy_logits=target_logits, - actions=tf.unstack(tf.cast(actions, tf.int32), axis=2), - discounts=tf.to_float(~dones) * discount, - rewards=rewards, - values=values, - bootstrap_value=bootstrap_value, - clip_rho_threshold=tf.cast(clip_rho_threshold, tf.float32), - clip_pg_rho_threshold=tf.cast(clip_pg_rho_threshold, - tf.float32)) - - logp_ratio = tf.exp(actions_logp - prev_actions_logp) - - advantages = self.vtrace_returns.pg_advantages - surrogate_loss = tf.minimum( - advantages * logp_ratio, - advantages * tf.clip_by_value(logp_ratio, 1 - clip_param, - 1 + clip_param)) - - self.mean_kl = tf.reduce_mean(action_kl) - self.pi_loss = -tf.reduce_sum(surrogate_loss) - - # The baseline loss - delta = tf.boolean_mask(values - self.vtrace_returns.vs, valid_mask) - self.value_targets = self.vtrace_returns.vs - self.vf_loss = 0.5 * tf.reduce_sum(tf.square(delta)) - - # The entropy loss - self.entropy = tf.reduce_sum( - tf.boolean_mask(actions_entropy, valid_mask)) - - # The summed weighted loss - self.total_loss = (self.pi_loss + self.vf_loss * vf_loss_coeff - - self.entropy * entropy_coeff) - - -class APPOPostprocessing(object): - """Adds the policy logits, VF preds, and advantages to the trajectory.""" - - @override(TFPolicyGraph) - def extra_compute_action_fetches(self): - out = {"behaviour_logits": self.model.outputs} - if not self.config["vtrace"]: - out["vf_preds"] = self.value_function - return dict(TFPolicyGraph.extra_compute_action_fetches(self), **out) - - @override(PolicyGraph) - def postprocess_trajectory(self, - sample_batch, - other_agent_batches=None, - episode=None): - if not self.config["vtrace"]: - completed = sample_batch["dones"][-1] - if completed: - last_r = 0.0 - else: - next_state = [] - for i in range(len(self.model.state_in)): - next_state.append( - [sample_batch["state_out_{}".format(i)][-1]]) - last_r = self.value(sample_batch["new_obs"][-1], *next_state) - batch = compute_advantages( - sample_batch, - last_r, - self.config["gamma"], - self.config["lambda"], - use_gae=self.config["use_gae"]) - else: - batch = sample_batch - del batch.data["new_obs"] # not used, so save some bandwidth - return batch - - -class AsyncPPOPolicyGraph(LearningRateSchedule, APPOPostprocessing, - TFPolicyGraph): - def __init__(self, - observation_space, - action_space, - config, - existing_inputs=None): - config = dict(ray.rllib.agents.impala.impala.DEFAULT_CONFIG, **config) - assert config["batch_mode"] == "truncate_episodes", \ - "Must use `truncate_episodes` batch mode with V-trace." 
- self.config = config - self.sess = tf.get_default_session() - self.grads = None - - if isinstance(action_space, gym.spaces.Discrete): - is_multidiscrete = False - output_hidden_shape = [action_space.n] - elif isinstance(action_space, gym.spaces.multi_discrete.MultiDiscrete): - is_multidiscrete = True - output_hidden_shape = action_space.nvec.astype(np.int32) - elif self.config["vtrace"]: - raise UnsupportedSpaceException( - "Action space {} is not supported for APPO + VTrace.", - format(action_space)) - else: - is_multidiscrete = False - output_hidden_shape = 1 - - # Policy network model - dist_class, logit_dim = ModelCatalog.get_action_dist( - action_space, self.config["model"]) - - # Create input placeholders - if existing_inputs: - if self.config["vtrace"]: - actions, dones, behaviour_logits, rewards, observations, \ - prev_actions, prev_rewards = existing_inputs[:7] - existing_state_in = existing_inputs[7:-1] - existing_seq_lens = existing_inputs[-1] - else: - actions, dones, behaviour_logits, rewards, observations, \ - prev_actions, prev_rewards, adv_ph, value_targets = \ - existing_inputs[:9] - existing_state_in = existing_inputs[9:-1] - existing_seq_lens = existing_inputs[-1] - else: - actions = ModelCatalog.get_action_placeholder(action_space) - dones = tf.placeholder(tf.bool, [None], name="dones") - rewards = tf.placeholder(tf.float32, [None], name="rewards") - behaviour_logits = tf.placeholder( - tf.float32, [None, logit_dim], name="behaviour_logits") - observations = tf.placeholder( - tf.float32, [None] + list(observation_space.shape)) - existing_state_in = None - existing_seq_lens = None - - if not self.config["vtrace"]: - adv_ph = tf.placeholder( - tf.float32, name="advantages", shape=(None, )) - value_targets = tf.placeholder( - tf.float32, name="value_targets", shape=(None, )) - self.observations = observations - - # Unpack behaviour logits - unpacked_behaviour_logits = tf.split( - behaviour_logits, output_hidden_shape, axis=1) - - # Setup the policy - dist_class, logit_dim = ModelCatalog.get_action_dist( - action_space, self.config["model"]) - prev_actions = ModelCatalog.get_action_placeholder(action_space) - prev_rewards = tf.placeholder(tf.float32, [None], name="prev_reward") - self.model = ModelCatalog.get_model( - { - "obs": observations, - "prev_actions": prev_actions, - "prev_rewards": prev_rewards, - "is_training": self._get_is_training_placeholder(), - }, - observation_space, - action_space, - logit_dim, - self.config["model"], - state_in=existing_state_in, - seq_lens=existing_seq_lens) - unpacked_outputs = tf.split( - self.model.outputs, output_hidden_shape, axis=1) - - dist_inputs = unpacked_outputs if is_multidiscrete else \ - self.model.outputs - prev_dist_inputs = unpacked_behaviour_logits if is_multidiscrete else \ - behaviour_logits - - action_dist = dist_class(dist_inputs) - prev_action_dist = dist_class(prev_dist_inputs) - - values = self.model.value_function() - self.value_function = values - self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, - tf.get_variable_scope().name) - - def make_time_major(tensor, drop_last=False): - """Swaps batch and trajectory axis. - Args: - tensor: A tensor or list of tensors to reshape. - drop_last: A bool indicating whether to drop the last - trajectory item. - Returns: - res: A tensor with swapped axes or a list of tensors with - swapped axes. 
- """ - if isinstance(tensor, list): - return [make_time_major(t, drop_last) for t in tensor] - - if self.model.state_init: - B = tf.shape(self.model.seq_lens)[0] - T = tf.shape(tensor)[0] // B - else: - # Important: chop the tensor into batches at known episode cut - # boundaries. TODO(ekl) this is kind of a hack - T = self.config["sample_batch_size"] - B = tf.shape(tensor)[0] // T - rs = tf.reshape(tensor, - tf.concat([[B, T], tf.shape(tensor)[1:]], axis=0)) - - # swap B and T axes - res = tf.transpose( - rs, - [1, 0] + list(range(2, 1 + int(tf.shape(tensor).shape[0])))) - - if drop_last: - return res[:-1] - return res - - if self.model.state_in: - max_seq_len = tf.reduce_max(self.model.seq_lens) - 1 - mask = tf.sequence_mask(self.model.seq_lens, max_seq_len) - mask = tf.reshape(mask, [-1]) - else: - mask = tf.ones_like(rewards) - - # Inputs are reshaped from [B * T] => [T - 1, B] for V-trace calc. - if self.config["vtrace"]: - logger.info("Using V-Trace surrogate loss (vtrace=True)") - - # Prepare actions for loss - loss_actions = actions if is_multidiscrete else tf.expand_dims( - actions, axis=1) - - self.loss = VTraceSurrogateLoss( - actions=make_time_major(loss_actions, drop_last=True), - prev_actions_logp=make_time_major( - prev_action_dist.logp(actions), drop_last=True), - actions_logp=make_time_major( - action_dist.logp(actions), drop_last=True), - action_kl=prev_action_dist.kl(action_dist), - actions_entropy=make_time_major( - action_dist.entropy(), drop_last=True), - dones=make_time_major(dones, drop_last=True), - behaviour_logits=make_time_major( - unpacked_behaviour_logits, drop_last=True), - target_logits=make_time_major( - unpacked_outputs, drop_last=True), - discount=config["gamma"], - rewards=make_time_major(rewards, drop_last=True), - values=make_time_major(values, drop_last=True), - bootstrap_value=make_time_major(values)[-1], - valid_mask=make_time_major(mask, drop_last=True), - vf_loss_coeff=self.config["vf_loss_coeff"], - entropy_coeff=self.config["entropy_coeff"], - clip_rho_threshold=self.config["vtrace_clip_rho_threshold"], - clip_pg_rho_threshold=self.config[ - "vtrace_clip_pg_rho_threshold"], - clip_param=self.config["clip_param"]) - else: - logger.info("Using PPO surrogate loss (vtrace=False)") - self.loss = PPOSurrogateLoss( - prev_actions_logp=make_time_major( - prev_action_dist.logp(actions)), - actions_logp=make_time_major(action_dist.logp(actions)), - action_kl=prev_action_dist.kl(action_dist), - actions_entropy=make_time_major(action_dist.entropy()), - values=make_time_major(values), - valid_mask=make_time_major(mask), - advantages=make_time_major(adv_ph), - value_targets=make_time_major(value_targets), - vf_loss_coeff=self.config["vf_loss_coeff"], - entropy_coeff=self.config["entropy_coeff"], - clip_param=self.config["clip_param"]) - - # KL divergence between worker and learner logits for debugging - model_dist = MultiCategorical(unpacked_outputs) - behaviour_dist = MultiCategorical(unpacked_behaviour_logits) - - kls = model_dist.kl(behaviour_dist) - if len(kls) > 1: - self.KL_stats = {} - - for i, kl in enumerate(kls): - self.KL_stats.update({ - "mean_KL_{}".format(i): tf.reduce_mean(kl), - "max_KL_{}".format(i): tf.reduce_max(kl), - "median_KL_{}".format(i): tf.contrib.distributions. 
- percentile(kl, 50.0), - }) - else: - self.KL_stats = { - "mean_KL": tf.reduce_mean(kls[0]), - "max_KL": tf.reduce_max(kls[0]), - "median_KL": tf.contrib.distributions.percentile(kls[0], 50.0), - } - - # Initialize TFPolicyGraph - loss_in = [ - ("actions", actions), - ("dones", dones), - ("behaviour_logits", behaviour_logits), - ("rewards", rewards), - ("obs", observations), - ("prev_actions", prev_actions), - ("prev_rewards", prev_rewards), - ] - if not self.config["vtrace"]: - loss_in.append(("advantages", adv_ph)) - loss_in.append(("value_targets", value_targets)) - LearningRateSchedule.__init__(self, self.config["lr"], - self.config["lr_schedule"]) - TFPolicyGraph.__init__( - self, - observation_space, - action_space, - self.sess, - obs_input=observations, - action_sampler=action_dist.sample(), - action_prob=action_dist.sampled_action_prob(), - loss=self.loss.total_loss, - model=self.model, - loss_inputs=loss_in, - state_inputs=self.model.state_in, - state_outputs=self.model.state_out, - prev_action_input=prev_actions, - prev_reward_input=prev_rewards, - seq_lens=self.model.seq_lens, - max_seq_len=self.config["model"]["max_seq_len"], - batch_divisibility_req=self.config["sample_batch_size"]) - - self.sess.run(tf.global_variables_initializer()) - - values_batched = make_time_major( - values, drop_last=self.config["vtrace"]) - self.stats_fetches = { - LEARNER_STATS_KEY: dict({ - "cur_lr": tf.cast(self.cur_lr, tf.float64), - "policy_loss": self.loss.pi_loss, - "entropy": self.loss.entropy, - "grad_gnorm": tf.global_norm(self._grads), - "var_gnorm": tf.global_norm(self.var_list), - "vf_loss": self.loss.vf_loss, - "vf_explained_var": explained_variance( - tf.reshape(self.loss.value_targets, [-1]), - tf.reshape(values_batched, [-1])), - }, **self.KL_stats), - } - - def optimizer(self): - if self.config["opt_type"] == "adam": - return tf.train.AdamOptimizer(self.cur_lr) - else: - return tf.train.RMSPropOptimizer(self.cur_lr, self.config["decay"], - self.config["momentum"], - self.config["epsilon"]) - - def gradients(self, optimizer, loss): - grads = tf.gradients(loss, self.var_list) - self.grads, _ = tf.clip_by_global_norm(grads, self.config["grad_clip"]) - clipped_grads = list(zip(self.grads, self.var_list)) - return clipped_grads - - def extra_compute_grad_fetches(self): - return self.stats_fetches - - def value(self, ob, *args): - feed_dict = {self.observations: [ob], self.model.seq_lens: [1]} - assert len(args) == len(self.model.state_in), \ - (args, self.model.state_in) - for k, v in zip(self.model.state_in, args): - feed_dict[k] = v - vf = self.sess.run(self.value_function, feed_dict) - return vf[0] - - def get_initial_state(self): - return self.model.state_init - - def copy(self, existing_inputs): - return AsyncPPOPolicyGraph( - self.observation_space, - self.action_space, - self.config, - existing_inputs=existing_inputs) diff --git a/python/ray/rllib/agents/ppo/ppo.py b/python/ray/rllib/agents/ppo/ppo.py index 8f69c91149e7..daf43d14821d 100644 --- a/python/ray/rllib/agents/ppo/ppo.py +++ b/python/ray/rllib/agents/ppo/ppo.py @@ -4,10 +4,10 @@ import logging -from ray.rllib.agents import Trainer, with_common_config -from ray.rllib.agents.ppo.ppo_policy_graph import PPOPolicyGraph +from ray.rllib.agents import with_common_config +from ray.rllib.agents.ppo.ppo_policy import PPOTFPolicy +from ray.rllib.agents.trainer_template import build_trainer from ray.rllib.optimizers import SyncSamplesOptimizer, LocalMultiGPUOptimizer -from ray.rllib.utils.annotations import override logger = 
logging.getLogger(__name__) @@ -63,110 +63,103 @@ # yapf: enable -class PPOTrainer(Trainer): - """Multi-GPU optimized implementation of PPO in TensorFlow.""" - - _name = "PPO" - _default_config = DEFAULT_CONFIG - _policy_graph = PPOPolicyGraph - - @override(Trainer) - def _init(self, config, env_creator): - self._validate_config() - self.local_evaluator = self.make_local_evaluator( - env_creator, self._policy_graph) - self.remote_evaluators = self.make_remote_evaluators( - env_creator, self._policy_graph, config["num_workers"]) - if config["simple_optimizer"]: - self.optimizer = SyncSamplesOptimizer( - self.local_evaluator, - self.remote_evaluators, - num_sgd_iter=config["num_sgd_iter"], - train_batch_size=config["train_batch_size"]) - else: - self.optimizer = LocalMultiGPUOptimizer( - self.local_evaluator, - self.remote_evaluators, - sgd_batch_size=config["sgd_minibatch_size"], - num_sgd_iter=config["num_sgd_iter"], - num_gpus=config["num_gpus"], - sample_batch_size=config["sample_batch_size"], - num_envs_per_worker=config["num_envs_per_worker"], - train_batch_size=config["train_batch_size"], - standardize_fields=["advantages"], - straggler_mitigation=config["straggler_mitigation"]) - - @override(Trainer) - def _train(self): - if "observation_filter" not in self.raw_user_config: - # TODO(ekl) remove this message after a few releases - logger.info( - "Important! Since 0.7.0, observation normalization is no " - "longer enabled by default. To enable running-mean " - "normalization, set 'observation_filter': 'MeanStdFilter'. " - "You can ignore this message if your environment doesn't " - "require observation normalization.") - prev_steps = self.optimizer.num_steps_sampled - fetches = self.optimizer.step() - if "kl" in fetches: - # single-agent - self.local_evaluator.for_policy( - lambda pi: pi.update_kl(fetches["kl"])) - else: - - def update(pi, pi_id): - if pi_id in fetches: - pi.update_kl(fetches[pi_id]["kl"]) - else: - logger.debug( - "No data for {}, not updating kl".format(pi_id)) - - # multi-agent - self.local_evaluator.foreach_trainable_policy(update) - res = self.collect_metrics() - res.update( - timesteps_this_iter=self.optimizer.num_steps_sampled - prev_steps, - info=res.get("info", {})) - - # Warn about bad clipping configs - if self.config["vf_clip_param"] <= 0: - rew_scale = float("inf") - elif res["policy_reward_mean"]: - rew_scale = 0 # punt on handling multiagent case - else: - rew_scale = round( - abs(res["episode_reward_mean"]) / self.config["vf_clip_param"], - 0) - if rew_scale > 200: - logger.warning( - "The magnitude of your environment rewards are more than " - "{}x the scale of `vf_clip_param`. ".format(rew_scale) + - "This means that it will take more than " - "{} iterations for your value ".format(rew_scale) + - "function to converge. If this is not intended, consider " - "increasing `vf_clip_param`.") - return res - - def _validate_config(self): - if self.config["entropy_coeff"] < 0: - raise DeprecationWarning("entropy_coeff must be >= 0") - if self.config["sgd_minibatch_size"] > self.config["train_batch_size"]: - raise ValueError( - "Minibatch size {} must be <= train batch size {}.".format( - self.config["sgd_minibatch_size"], - self.config["train_batch_size"])) - if (self.config["batch_mode"] == "truncate_episodes" - and not self.config["use_gae"]): - raise ValueError( - "Episode truncation is not supported without a value " - "function. 
Consider setting batch_mode=complete_episodes.") - if (self.config["multiagent"]["policy_graphs"] - and not self.config["simple_optimizer"]): - logger.info( - "In multi-agent mode, policies will be optimized sequentially " - "by the multi-GPU optimizer. Consider setting " - "simple_optimizer=True if this doesn't work for you.") - if not self.config["vf_share_layers"]: - logger.warning( - "FYI: By default, the value function will not share layers " - "with the policy model ('vf_share_layers': False).") +def choose_policy_optimizer(local_evaluator, remote_evaluators, config): + if config["simple_optimizer"]: + return SyncSamplesOptimizer( + local_evaluator, + remote_evaluators, + num_sgd_iter=config["num_sgd_iter"], + train_batch_size=config["train_batch_size"]) + + return LocalMultiGPUOptimizer( + local_evaluator, + remote_evaluators, + sgd_batch_size=config["sgd_minibatch_size"], + num_sgd_iter=config["num_sgd_iter"], + num_gpus=config["num_gpus"], + sample_batch_size=config["sample_batch_size"], + num_envs_per_worker=config["num_envs_per_worker"], + train_batch_size=config["train_batch_size"], + standardize_fields=["advantages"], + straggler_mitigation=config["straggler_mitigation"]) + + +def update_kl(trainer, fetches): + if "kl" in fetches: + # single-agent + trainer.local_evaluator.for_policy( + lambda pi: pi.update_kl(fetches["kl"])) + else: + + def update(pi, pi_id): + if pi_id in fetches: + pi.update_kl(fetches[pi_id]["kl"]) + else: + logger.debug("No data for {}, not updating kl".format(pi_id)) + + # multi-agent + trainer.local_evaluator.foreach_trainable_policy(update) + + +def warn_about_obs_filter(trainer): + if "observation_filter" not in trainer.raw_user_config: + # TODO(ekl) remove this message after a few releases + logger.info( + "Important! Since 0.7.0, observation normalization is no " + "longer enabled by default. To enable running-mean " + "normalization, set 'observation_filter': 'MeanStdFilter'. " + "You can ignore this message if your environment doesn't " + "require observation normalization.") + + +def warn_about_bad_reward_scales(trainer, result): + # Warn about bad clipping configs + if trainer.config["vf_clip_param"] <= 0: + rew_scale = float("inf") + elif result["policy_reward_mean"]: + rew_scale = 0 # punt on handling multiagent case + else: + rew_scale = round( + abs(result["episode_reward_mean"]) / + trainer.config["vf_clip_param"], 0) + if rew_scale > 200: + logger.warning( + "The magnitude of your environment rewards are more than " + "{}x the scale of `vf_clip_param`. ".format(rew_scale) + + "This means that it will take more than " + "{} iterations for your value ".format(rew_scale) + + "function to converge. If this is not intended, consider " + "increasing `vf_clip_param`.") + + +def validate_config(config): + if config["entropy_coeff"] < 0: + raise DeprecationWarning("entropy_coeff must be >= 0") + if config["sgd_minibatch_size"] > config["train_batch_size"]: + raise ValueError( + "Minibatch size {} must be <= train batch size {}.".format( + config["sgd_minibatch_size"], config["train_batch_size"])) + if (config["batch_mode"] == "truncate_episodes" and not config["use_gae"]): + raise ValueError( + "Episode truncation is not supported without a value " + "function. Consider setting batch_mode=complete_episodes.") + if (config["multiagent"]["policies"] and not config["simple_optimizer"]): + logger.info( + "In multi-agent mode, policies will be optimized sequentially " + "by the multi-GPU optimizer. 
Consider setting " + "simple_optimizer=True if this doesn't work for you.") + if not config["vf_share_layers"]: + logger.warning( + "FYI: By default, the value function will not share layers " + "with the policy model ('vf_share_layers': False).") + + +PPOTrainer = build_trainer( + name="PPOTrainer", + default_config=DEFAULT_CONFIG, + default_policy=PPOTFPolicy, + make_policy_optimizer=choose_policy_optimizer, + validate_config=validate_config, + after_optimizer_step=update_kl, + before_train_step=warn_about_obs_filter, + after_train_result=warn_about_bad_reward_scales) diff --git a/python/ray/rllib/agents/ppo/ppo_policy.py b/python/ray/rllib/agents/ppo/ppo_policy.py new file mode 100644 index 000000000000..5a17d6c6d60c --- /dev/null +++ b/python/ray/rllib/agents/ppo/ppo_policy.py @@ -0,0 +1,283 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import logging + +import ray +from ray.rllib.evaluation.postprocessing import compute_advantages, \ + Postprocessing +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.policy.tf_policy import LearningRateSchedule +from ray.rllib.policy.tf_policy_template import build_tf_policy +from ray.rllib.models.catalog import ModelCatalog +from ray.rllib.utils.explained_variance import explained_variance +from ray.rllib.utils import try_import_tf + +tf = try_import_tf() + +logger = logging.getLogger(__name__) + +# Frozen logits of the policy that computed the action +BEHAVIOUR_LOGITS = "behaviour_logits" + + +class PPOLoss(object): + def __init__(self, + action_space, + value_targets, + advantages, + actions, + logits, + vf_preds, + curr_action_dist, + value_fn, + cur_kl_coeff, + valid_mask, + entropy_coeff=0, + clip_param=0.1, + vf_clip_param=0.1, + vf_loss_coeff=1.0, + use_gae=True): + """Constructs the loss for Proximal Policy Objective. + + Arguments: + action_space: Environment observation space specification. + value_targets (Placeholder): Placeholder for target values; used + for GAE. + actions (Placeholder): Placeholder for actions taken + from previous model evaluation. + advantages (Placeholder): Placeholder for calculated advantages + from previous model evaluation. + logits (Placeholder): Placeholder for logits output from + previous model evaluation. + vf_preds (Placeholder): Placeholder for value function output + from previous model evaluation. + curr_action_dist (ActionDistribution): ActionDistribution + of the current model. + value_fn (Tensor): Current value function output Tensor. + cur_kl_coeff (Variable): Variable holding the current PPO KL + coefficient. + valid_mask (Tensor): A bool mask of valid input elements (#2992). + entropy_coeff (float): Coefficient of the entropy regularizer. + clip_param (float): Clip parameter + vf_clip_param (float): Clip parameter for the value function + vf_loss_coeff (float): Coefficient of the value function loss + use_gae (bool): If true, use the Generalized Advantage Estimator. + """ + + def reduce_mean_valid(t): + return tf.reduce_mean(tf.boolean_mask(t, valid_mask)) + + dist_cls, _ = ModelCatalog.get_action_dist(action_space, {}) + prev_dist = dist_cls(logits) + # Make loss functions. 
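        # A compact restatement of the objective assembled in the lines below,
        # derived from this constructor's own code (A_t denotes `advantages`,
        # KL(old || new) denotes `prev_dist.kl(curr_action_dist)`):
        #   ratio_t   = exp(logp_new(a_t) - logp_old(a_t))
        #   surrogate = min(ratio_t * A_t,
        #                   clip(ratio_t, 1 - clip_param, 1 + clip_param) * A_t)
        #   loss      = masked_mean(-surrogate + cur_kl_coeff * KL(old || new)
        #                           + vf_loss_coeff * vf_loss
        #                           - entropy_coeff * entropy)
        # where vf_loss is the (optionally clipped) squared error against
        # value_targets when use_gae is True and is dropped otherwise, and
        # masked_mean is the reduce_mean_valid helper above.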
+ logp_ratio = tf.exp( + curr_action_dist.logp(actions) - prev_dist.logp(actions)) + action_kl = prev_dist.kl(curr_action_dist) + self.mean_kl = reduce_mean_valid(action_kl) + + curr_entropy = curr_action_dist.entropy() + self.mean_entropy = reduce_mean_valid(curr_entropy) + + surrogate_loss = tf.minimum( + advantages * logp_ratio, + advantages * tf.clip_by_value(logp_ratio, 1 - clip_param, + 1 + clip_param)) + self.mean_policy_loss = reduce_mean_valid(-surrogate_loss) + + if use_gae: + vf_loss1 = tf.square(value_fn - value_targets) + vf_clipped = vf_preds + tf.clip_by_value( + value_fn - vf_preds, -vf_clip_param, vf_clip_param) + vf_loss2 = tf.square(vf_clipped - value_targets) + vf_loss = tf.maximum(vf_loss1, vf_loss2) + self.mean_vf_loss = reduce_mean_valid(vf_loss) + loss = reduce_mean_valid( + -surrogate_loss + cur_kl_coeff * action_kl + + vf_loss_coeff * vf_loss - entropy_coeff * curr_entropy) + else: + self.mean_vf_loss = tf.constant(0.0) + loss = reduce_mean_valid(-surrogate_loss + + cur_kl_coeff * action_kl - + entropy_coeff * curr_entropy) + self.loss = loss + + +def ppo_surrogate_loss(policy, batch_tensors): + if policy.model.state_in: + max_seq_len = tf.reduce_max(policy.model.seq_lens) + mask = tf.sequence_mask(policy.model.seq_lens, max_seq_len) + mask = tf.reshape(mask, [-1]) + else: + mask = tf.ones_like( + batch_tensors[Postprocessing.ADVANTAGES], dtype=tf.bool) + + policy.loss_obj = PPOLoss( + policy.action_space, + batch_tensors[Postprocessing.VALUE_TARGETS], + batch_tensors[Postprocessing.ADVANTAGES], + batch_tensors[SampleBatch.ACTIONS], + batch_tensors[BEHAVIOUR_LOGITS], + batch_tensors[SampleBatch.VF_PREDS], + policy.action_dist, + policy.value_function, + policy.kl_coeff, + mask, + entropy_coeff=policy.config["entropy_coeff"], + clip_param=policy.config["clip_param"], + vf_clip_param=policy.config["vf_clip_param"], + vf_loss_coeff=policy.config["vf_loss_coeff"], + use_gae=policy.config["use_gae"]) + + return policy.loss_obj.loss + + +def kl_and_loss_stats(policy, batch_tensors): + policy.explained_variance = explained_variance( + batch_tensors[Postprocessing.VALUE_TARGETS], policy.value_function) + + stats_fetches = { + "cur_kl_coeff": policy.kl_coeff, + "cur_lr": tf.cast(policy.cur_lr, tf.float64), + "total_loss": policy.loss_obj.loss, + "policy_loss": policy.loss_obj.mean_policy_loss, + "vf_loss": policy.loss_obj.mean_vf_loss, + "vf_explained_var": policy.explained_variance, + "kl": policy.loss_obj.mean_kl, + "entropy": policy.loss_obj.mean_entropy, + } + + return stats_fetches + + +def vf_preds_and_logits_fetches(policy): + """Adds value function and logits outputs to experience batches.""" + return { + SampleBatch.VF_PREDS: policy.value_function, + BEHAVIOUR_LOGITS: policy.model.outputs, + } + + +def postprocess_ppo_gae(policy, + sample_batch, + other_agent_batches=None, + episode=None): + """Adds the policy logits, VF preds, and advantages to the trajectory.""" + + completed = sample_batch["dones"][-1] + if completed: + last_r = 0.0 + else: + next_state = [] + for i in range(len(policy.model.state_in)): + next_state.append([sample_batch["state_out_{}".format(i)][-1]]) + last_r = policy._value(sample_batch[SampleBatch.NEXT_OBS][-1], + sample_batch[SampleBatch.ACTIONS][-1], + sample_batch[SampleBatch.REWARDS][-1], + *next_state) + batch = compute_advantages( + sample_batch, + last_r, + policy.config["gamma"], + policy.config["lambda"], + use_gae=policy.config["use_gae"]) + return batch + + +def clip_gradients(policy, optimizer, loss): + if 
policy.config["grad_clip"] is not None: + policy.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, + tf.get_variable_scope().name) + grads = tf.gradients(loss, policy.var_list) + policy.grads, _ = tf.clip_by_global_norm(grads, + policy.config["grad_clip"]) + clipped_grads = list(zip(policy.grads, policy.var_list)) + return clipped_grads + else: + return optimizer.compute_gradients( + loss, colocate_gradients_with_ops=True) + + +class KLCoeffMixin(object): + def __init__(self, config): + # KL Coefficient + self.kl_coeff_val = config["kl_coeff"] + self.kl_target = config["kl_target"] + self.kl_coeff = tf.get_variable( + initializer=tf.constant_initializer(self.kl_coeff_val), + name="kl_coeff", + shape=(), + trainable=False, + dtype=tf.float32) + + def update_kl(self, sampled_kl): + if sampled_kl > 2.0 * self.kl_target: + self.kl_coeff_val *= 1.5 + elif sampled_kl < 0.5 * self.kl_target: + self.kl_coeff_val *= 0.5 + self.kl_coeff.load(self.kl_coeff_val, session=self._sess) + return self.kl_coeff_val + + +class ValueNetworkMixin(object): + def __init__(self, obs_space, action_space, config): + if config["use_gae"]: + if config["vf_share_layers"]: + self.value_function = self.model.value_function() + else: + vf_config = config["model"].copy() + # Do not split the last layer of the value function into + # mean parameters and standard deviation parameters and + # do not make the standard deviations free variables. + vf_config["free_log_std"] = False + if vf_config["use_lstm"]: + vf_config["use_lstm"] = False + logger.warning( + "It is not recommended to use a LSTM model with " + "vf_share_layers=False (consider setting it to True). " + "If you want to not share layers, you can implement " + "a custom LSTM model that overrides the " + "value_function() method.") + with tf.variable_scope("value_function"): + self.value_function = ModelCatalog.get_model({ + "obs": self._obs_input, + "prev_actions": self._prev_action_input, + "prev_rewards": self._prev_reward_input, + "is_training": self._get_is_training_placeholder(), + }, obs_space, action_space, 1, vf_config).outputs + self.value_function = tf.reshape(self.value_function, [-1]) + else: + self.value_function = tf.zeros(shape=tf.shape(self._obs_input)[:1]) + + def _value(self, ob, prev_action, prev_reward, *args): + feed_dict = { + self._obs_input: [ob], + self._prev_action_input: [prev_action], + self._prev_reward_input: [prev_reward], + self.model.seq_lens: [1] + } + assert len(args) == len(self.model.state_in), \ + (args, self.model.state_in) + for k, v in zip(self.model.state_in, args): + feed_dict[k] = v + vf = self._sess.run(self.value_function, feed_dict) + return vf[0] + + +def setup_mixins(policy, obs_space, action_space, config): + ValueNetworkMixin.__init__(policy, obs_space, action_space, config) + KLCoeffMixin.__init__(policy, config) + LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"]) + + +PPOTFPolicy = build_tf_policy( + name="PPOTFPolicy", + get_default_config=lambda: ray.rllib.agents.ppo.ppo.DEFAULT_CONFIG, + loss_fn=ppo_surrogate_loss, + stats_fn=kl_and_loss_stats, + extra_action_fetches_fn=vf_preds_and_logits_fetches, + postprocess_fn=postprocess_ppo_gae, + gradients_fn=clip_gradients, + before_loss_init=setup_mixins, + mixins=[LearningRateSchedule, KLCoeffMixin, ValueNetworkMixin]) diff --git a/python/ray/rllib/agents/ppo/ppo_policy_graph.py b/python/ray/rllib/agents/ppo/ppo_policy_graph.py deleted file mode 100644 index fdaae4555789..000000000000 --- 
a/python/ray/rllib/agents/ppo/ppo_policy_graph.py +++ /dev/null @@ -1,369 +0,0 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import logging -import tensorflow as tf - -import ray -from ray.rllib.evaluation.postprocessing import compute_advantages, \ - Postprocessing -from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY -from ray.rllib.evaluation.policy_graph import PolicyGraph -from ray.rllib.evaluation.sample_batch import SampleBatch -from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph, \ - LearningRateSchedule -from ray.rllib.models.catalog import ModelCatalog -from ray.rllib.utils.annotations import override -from ray.rllib.utils.explained_variance import explained_variance - -logger = logging.getLogger(__name__) - -# Frozen logits of the policy that computed the action -BEHAVIOUR_LOGITS = "behaviour_logits" - - -class PPOLoss(object): - def __init__(self, - action_space, - value_targets, - advantages, - actions, - logits, - vf_preds, - curr_action_dist, - value_fn, - cur_kl_coeff, - valid_mask, - entropy_coeff=0, - clip_param=0.1, - vf_clip_param=0.1, - vf_loss_coeff=1.0, - use_gae=True): - """Constructs the loss for Proximal Policy Objective. - - Arguments: - action_space: Environment observation space specification. - value_targets (Placeholder): Placeholder for target values; used - for GAE. - actions (Placeholder): Placeholder for actions taken - from previous model evaluation. - advantages (Placeholder): Placeholder for calculated advantages - from previous model evaluation. - logits (Placeholder): Placeholder for logits output from - previous model evaluation. - vf_preds (Placeholder): Placeholder for value function output - from previous model evaluation. - curr_action_dist (ActionDistribution): ActionDistribution - of the current model. - value_fn (Tensor): Current value function output Tensor. - cur_kl_coeff (Variable): Variable holding the current PPO KL - coefficient. - valid_mask (Tensor): A bool mask of valid input elements (#2992). - entropy_coeff (float): Coefficient of the entropy regularizer. - clip_param (float): Clip parameter - vf_clip_param (float): Clip parameter for the value function - vf_loss_coeff (float): Coefficient of the value function loss - use_gae (bool): If true, use the Generalized Advantage Estimator. - """ - - def reduce_mean_valid(t): - return tf.reduce_mean(tf.boolean_mask(t, valid_mask)) - - dist_cls, _ = ModelCatalog.get_action_dist(action_space, {}) - prev_dist = dist_cls(logits) - # Make loss functions. 
- logp_ratio = tf.exp( - curr_action_dist.logp(actions) - prev_dist.logp(actions)) - action_kl = prev_dist.kl(curr_action_dist) - self.mean_kl = reduce_mean_valid(action_kl) - - curr_entropy = curr_action_dist.entropy() - self.mean_entropy = reduce_mean_valid(curr_entropy) - - surrogate_loss = tf.minimum( - advantages * logp_ratio, - advantages * tf.clip_by_value(logp_ratio, 1 - clip_param, - 1 + clip_param)) - self.mean_policy_loss = reduce_mean_valid(-surrogate_loss) - - if use_gae: - vf_loss1 = tf.square(value_fn - value_targets) - vf_clipped = vf_preds + tf.clip_by_value( - value_fn - vf_preds, -vf_clip_param, vf_clip_param) - vf_loss2 = tf.square(vf_clipped - value_targets) - vf_loss = tf.maximum(vf_loss1, vf_loss2) - self.mean_vf_loss = reduce_mean_valid(vf_loss) - loss = reduce_mean_valid( - -surrogate_loss + cur_kl_coeff * action_kl + - vf_loss_coeff * vf_loss - entropy_coeff * curr_entropy) - else: - self.mean_vf_loss = tf.constant(0.0) - loss = reduce_mean_valid(-surrogate_loss + - cur_kl_coeff * action_kl - - entropy_coeff * curr_entropy) - self.loss = loss - - -class PPOPostprocessing(object): - """Adds the policy logits, VF preds, and advantages to the trajectory.""" - - @override(TFPolicyGraph) - def extra_compute_action_fetches(self): - return dict( - TFPolicyGraph.extra_compute_action_fetches(self), **{ - SampleBatch.VF_PREDS: self.value_function, - BEHAVIOUR_LOGITS: self.logits - }) - - @override(PolicyGraph) - def postprocess_trajectory(self, - sample_batch, - other_agent_batches=None, - episode=None): - completed = sample_batch["dones"][-1] - if completed: - last_r = 0.0 - else: - next_state = [] - for i in range(len(self.model.state_in)): - next_state.append([sample_batch["state_out_{}".format(i)][-1]]) - last_r = self._value(sample_batch[SampleBatch.NEXT_OBS][-1], - sample_batch[SampleBatch.ACTIONS][-1], - sample_batch[SampleBatch.REWARDS][-1], - *next_state) - batch = compute_advantages( - sample_batch, - last_r, - self.config["gamma"], - self.config["lambda"], - use_gae=self.config["use_gae"]) - return batch - - -class PPOPolicyGraph(LearningRateSchedule, PPOPostprocessing, TFPolicyGraph): - def __init__(self, - observation_space, - action_space, - config, - existing_inputs=None): - """ - Arguments: - observation_space: Environment observation space specification. - action_space: Environment action space specification. - config (dict): Configuration values for PPO graph. - existing_inputs (list): Optional list of tuples that specify the - placeholders upon which the graph should be built upon. 
- """ - config = dict(ray.rllib.agents.ppo.ppo.DEFAULT_CONFIG, **config) - self.sess = tf.get_default_session() - self.action_space = action_space - self.config = config - self.kl_coeff_val = self.config["kl_coeff"] - self.kl_target = self.config["kl_target"] - dist_cls, logit_dim = ModelCatalog.get_action_dist( - action_space, self.config["model"]) - - if existing_inputs: - obs_ph, value_targets_ph, adv_ph, act_ph, \ - logits_ph, vf_preds_ph, prev_actions_ph, prev_rewards_ph = \ - existing_inputs[:8] - existing_state_in = existing_inputs[8:-1] - existing_seq_lens = existing_inputs[-1] - else: - obs_ph = tf.placeholder( - tf.float32, - name="obs", - shape=(None, ) + observation_space.shape) - adv_ph = tf.placeholder( - tf.float32, name="advantages", shape=(None, )) - act_ph = ModelCatalog.get_action_placeholder(action_space) - logits_ph = tf.placeholder( - tf.float32, name="logits", shape=(None, logit_dim)) - vf_preds_ph = tf.placeholder( - tf.float32, name="vf_preds", shape=(None, )) - value_targets_ph = tf.placeholder( - tf.float32, name="value_targets", shape=(None, )) - prev_actions_ph = ModelCatalog.get_action_placeholder(action_space) - prev_rewards_ph = tf.placeholder( - tf.float32, [None], name="prev_reward") - existing_state_in = None - existing_seq_lens = None - self.observations = obs_ph - self.prev_actions = prev_actions_ph - self.prev_rewards = prev_rewards_ph - - self.loss_in = [ - (SampleBatch.CUR_OBS, obs_ph), - (Postprocessing.VALUE_TARGETS, value_targets_ph), - (Postprocessing.ADVANTAGES, adv_ph), - (SampleBatch.ACTIONS, act_ph), - (BEHAVIOUR_LOGITS, logits_ph), - (SampleBatch.VF_PREDS, vf_preds_ph), - (SampleBatch.PREV_ACTIONS, prev_actions_ph), - (SampleBatch.PREV_REWARDS, prev_rewards_ph), - ] - self.model = ModelCatalog.get_model( - { - "obs": obs_ph, - "prev_actions": prev_actions_ph, - "prev_rewards": prev_rewards_ph, - "is_training": self._get_is_training_placeholder(), - }, - observation_space, - action_space, - logit_dim, - self.config["model"], - state_in=existing_state_in, - seq_lens=existing_seq_lens) - - # KL Coefficient - self.kl_coeff = tf.get_variable( - initializer=tf.constant_initializer(self.kl_coeff_val), - name="kl_coeff", - shape=(), - trainable=False, - dtype=tf.float32) - - self.logits = self.model.outputs - curr_action_dist = dist_cls(self.logits) - self.sampler = curr_action_dist.sample() - if self.config["use_gae"]: - if self.config["vf_share_layers"]: - self.value_function = self.model.value_function() - else: - vf_config = self.config["model"].copy() - # Do not split the last layer of the value function into - # mean parameters and standard deviation parameters and - # do not make the standard deviations free variables. - vf_config["free_log_std"] = False - if vf_config["use_lstm"]: - vf_config["use_lstm"] = False - logger.warning( - "It is not recommended to use a LSTM model with " - "vf_share_layers=False (consider setting it to True). 
" - "If you want to not share layers, you can implement " - "a custom LSTM model that overrides the " - "value_function() method.") - with tf.variable_scope("value_function"): - self.value_function = ModelCatalog.get_model({ - "obs": obs_ph, - "prev_actions": prev_actions_ph, - "prev_rewards": prev_rewards_ph, - "is_training": self._get_is_training_placeholder(), - }, observation_space, action_space, 1, vf_config).outputs - self.value_function = tf.reshape(self.value_function, [-1]) - else: - self.value_function = tf.zeros(shape=tf.shape(obs_ph)[:1]) - - if self.model.state_in: - max_seq_len = tf.reduce_max(self.model.seq_lens) - mask = tf.sequence_mask(self.model.seq_lens, max_seq_len) - mask = tf.reshape(mask, [-1]) - else: - mask = tf.ones_like(adv_ph, dtype=tf.bool) - - self.loss_obj = PPOLoss( - action_space, - value_targets_ph, - adv_ph, - act_ph, - logits_ph, - vf_preds_ph, - curr_action_dist, - self.value_function, - self.kl_coeff, - mask, - entropy_coeff=self.config["entropy_coeff"], - clip_param=self.config["clip_param"], - vf_clip_param=self.config["vf_clip_param"], - vf_loss_coeff=self.config["vf_loss_coeff"], - use_gae=self.config["use_gae"]) - - LearningRateSchedule.__init__(self, self.config["lr"], - self.config["lr_schedule"]) - self.state_values = self.value_function - TFPolicyGraph.__init__( - self, - observation_space, - action_space, - self.sess, - obs_input=obs_ph, - action_sampler=self.sampler, - action_prob=curr_action_dist.sampled_action_prob(), - loss=self.loss_obj.loss, - model=self.model, - loss_inputs=self.loss_in, - state_inputs=self.model.state_in, - state_outputs=self.model.state_out, - prev_action_input=prev_actions_ph, - prev_reward_input=prev_rewards_ph, - seq_lens=self.model.seq_lens, - max_seq_len=config["model"]["max_seq_len"]) - - self.sess.run(tf.global_variables_initializer()) - self.explained_variance = explained_variance(value_targets_ph, - self.value_function) - self.stats_fetches = { - "cur_kl_coeff": self.kl_coeff, - "cur_lr": tf.cast(self.cur_lr, tf.float64), - "total_loss": self.loss_obj.loss, - "policy_loss": self.loss_obj.mean_policy_loss, - "vf_loss": self.loss_obj.mean_vf_loss, - "vf_explained_var": self.explained_variance, - "kl": self.loss_obj.mean_kl, - "entropy": self.loss_obj.mean_entropy - } - - @override(TFPolicyGraph) - def copy(self, existing_inputs): - """Creates a copy of self using existing input placeholders.""" - return PPOPolicyGraph( - self.observation_space, - self.action_space, - self.config, - existing_inputs=existing_inputs) - - @override(TFPolicyGraph) - def gradients(self, optimizer, loss): - if self.config["grad_clip"] is not None: - self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, - tf.get_variable_scope().name) - grads = tf.gradients(loss, self.var_list) - self.grads, _ = tf.clip_by_global_norm(grads, - self.config["grad_clip"]) - clipped_grads = list(zip(self.grads, self.var_list)) - return clipped_grads - else: - return optimizer.compute_gradients( - loss, colocate_gradients_with_ops=True) - - @override(PolicyGraph) - def get_initial_state(self): - return self.model.state_init - - @override(TFPolicyGraph) - def extra_compute_grad_fetches(self): - return {LEARNER_STATS_KEY: self.stats_fetches} - - def update_kl(self, sampled_kl): - if sampled_kl > 2.0 * self.kl_target: - self.kl_coeff_val *= 1.5 - elif sampled_kl < 0.5 * self.kl_target: - self.kl_coeff_val *= 0.5 - self.kl_coeff.load(self.kl_coeff_val, session=self.sess) - return self.kl_coeff_val - - def _value(self, ob, prev_action, 
prev_reward, *args): - feed_dict = { - self.observations: [ob], - self.prev_actions: [prev_action], - self.prev_rewards: [prev_reward], - self.model.seq_lens: [1] - } - assert len(args) == len(self.model.state_in), \ - (args, self.model.state_in) - for k, v in zip(self.model.state_in, args): - feed_dict[k] = v - vf = self.sess.run(self.value_function, feed_dict) - return vf[0] diff --git a/python/ray/rllib/agents/ppo/test/test.py b/python/ray/rllib/agents/ppo/test/test.py index 432b22f9aed2..1091b639c6f4 100644 --- a/python/ray/rllib/agents/ppo/test/test.py +++ b/python/ray/rllib/agents/ppo/test/test.py @@ -4,11 +4,13 @@ import unittest import numpy as np -import tensorflow as tf from numpy.testing import assert_allclose from ray.rllib.models.action_dist import Categorical from ray.rllib.agents.ppo.utils import flatten, concatenate +from ray.rllib.utils import try_import_tf + +tf = try_import_tf() # TODO(ekl): move to rllib/models dir diff --git a/python/ray/rllib/agents/qmix/qmix.py b/python/ray/rllib/agents/qmix/qmix.py index 420a567d8eff..2ad6a3e56f95 100644 --- a/python/ray/rllib/agents/qmix/qmix.py +++ b/python/ray/rllib/agents/qmix/qmix.py @@ -4,7 +4,7 @@ from ray.rllib.agents.trainer import with_common_config from ray.rllib.agents.dqn.dqn import DQNTrainer -from ray.rllib.agents.qmix.qmix_policy_graph import QMixPolicyGraph +from ray.rllib.agents.qmix.qmix_policy import QMixTorchPolicy # yapf: disable # __sphinx_doc_begin__ @@ -95,7 +95,7 @@ class QMixTrainer(DQNTrainer): _name = "QMIX" _default_config = DEFAULT_CONFIG - _policy_graph = QMixPolicyGraph + _policy = QMixTorchPolicy _optimizer_shared_configs = [ "learning_starts", "buffer_size", "train_batch_size" ] diff --git a/python/ray/rllib/agents/qmix/qmix_policy_graph.py b/python/ray/rllib/agents/qmix/qmix_policy.py similarity index 98% rename from python/ray/rllib/agents/qmix/qmix_policy_graph.py rename to python/ray/rllib/agents/qmix/qmix_policy.py index b7c9a7ad8120..26ec387de004 100644 --- a/python/ray/rllib/agents/qmix/qmix_policy_graph.py +++ b/python/ray/rllib/agents/qmix/qmix_policy.py @@ -14,8 +14,8 @@ from ray.rllib.agents.qmix.mixers import VDNMixer, QMixer from ray.rllib.agents.qmix.model import RNNModel, _get_size from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY -from ray.rllib.evaluation.policy_graph import PolicyGraph -from ray.rllib.evaluation.sample_batch import SampleBatch +from ray.rllib.policy.policy import Policy +from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.models.action_dist import TupleActions from ray.rllib.models.catalog import ModelCatalog from ray.rllib.models.lstm import chop_into_sequences @@ -130,7 +130,7 @@ def forward(self, rewards, actions, terminated, mask, obs, next_obs, return loss, mask, masked_td_error, chosen_action_qvals, targets -class QMixPolicyGraph(PolicyGraph): +class QMixTorchPolicy(Policy): """QMix impl. Assumes homogeneous agents for now. 
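A hedged sketch of the agent grouping this docstring goes on to require via MultiAgentEnv.with_agent_groups(); the Tuple spaces, keyword names, agent ids, and MyMultiAgentEnv are assumptions following the usual RLlib grouping pattern, not taken from this diff:

    from gym.spaces import Discrete, Tuple
    from ray.tune.registry import register_env

    grouping = {"group_1": ["agent_0", "agent_1"]}  # hypothetical agent ids
    obs_space = Tuple([Discrete(6), Discrete(6)])
    act_space = Tuple([Discrete(2), Discrete(2)])
    register_env(
        "grouped_env",
        lambda cfg: MyMultiAgentEnv(cfg).with_agent_groups(  # MyMultiAgentEnv is a stand-in
            grouping, obs_space=obs_space, act_space=act_space))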
You must use MultiAgentEnv.with_agent_groups() to group agents @@ -213,7 +213,7 @@ def __init__(self, obs_space, action_space, config): alpha=config["optim_alpha"], eps=config["optim_eps"]) - @override(PolicyGraph) + @override(Policy) def compute_actions(self, obs_batch, state_batches=None, @@ -243,7 +243,7 @@ def compute_actions(self, return TupleActions(list(actions.transpose([1, 0]))), hiddens, {} - @override(PolicyGraph) + @override(Policy) def learn_on_batch(self, samples): obs_batch, action_mask = self._unpack_observation( samples[SampleBatch.CUR_OBS]) @@ -314,22 +314,22 @@ def to_batches(arr): } return {LEARNER_STATS_KEY: stats} - @override(PolicyGraph) + @override(Policy) def get_initial_state(self): return [ s.expand([self.n_agents, -1]).numpy() for s in self.model.state_init() ] - @override(PolicyGraph) + @override(Policy) def get_weights(self): return {"model": self.model.state_dict()} - @override(PolicyGraph) + @override(Policy) def set_weights(self, weights): self.model.load_state_dict(weights["model"]) - @override(PolicyGraph) + @override(Policy) def get_state(self): return { "model": self.model.state_dict(), @@ -340,7 +340,7 @@ def get_state(self): "cur_epsilon": self.cur_epsilon, } - @override(PolicyGraph) + @override(Policy) def set_state(self, state): self.model.load_state_dict(state["model"]) self.target_model.load_state_dict(state["target_model"]) diff --git a/python/ray/rllib/agents/trainer.py b/python/ray/rllib/agents/trainer.py index 3b92cff22494..d49eecbf9d2f 100644 --- a/python/ray/rllib/agents/trainer.py +++ b/python/ray/rllib/agents/trainer.py @@ -2,35 +2,38 @@ from __future__ import division from __future__ import print_function -from datetime import datetime import copy import logging import os import pickle -import six -import time import tempfile -import tensorflow as tf +import time +from datetime import datetime from types import FunctionType import ray +import six from ray.exceptions import RayError -from ray.rllib.offline import NoopOutput, JsonReader, MixedInput, JsonWriter, \ - ShuffledInput -from ray.rllib.models import MODEL_DEFAULTS +from ray.rllib.evaluation.metrics import collect_metrics from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator, \ _validate_multiagent_config -from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID -from ray.rllib.evaluation.metrics import collect_metrics +from ray.rllib.models import MODEL_DEFAULTS +from ray.rllib.offline import NoopOutput, JsonReader, MixedInput, JsonWriter, \ + ShuffledInput from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer -from ray.rllib.utils.annotations import override, PublicAPI, DeveloperAPI +from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID from ray.rllib.utils import FilterManager, deep_update, merge_dicts +from ray.rllib.utils import try_import_tf +from ray.rllib.utils.annotations import override, PublicAPI, DeveloperAPI from ray.rllib.utils.memory import ray_get_and_free +from ray.tune.logger import UnifiedLogger +from ray.tune.logger import to_tf_values from ray.tune.registry import ENV_CREATOR, register_env, _global_registry +from ray.tune.result import DEFAULT_RESULTS_DIR from ray.tune.trainable import Trainable from ray.tune.trial import Resources, ExportFormat -from ray.tune.logger import UnifiedLogger -from ray.tune.result import DEFAULT_RESULTS_DIR + +tf = try_import_tf() logger = logging.getLogger(__name__) @@ -54,11 +57,11 @@ # metrics can be attached to the episode by updating the episode object's # custom metrics dict (see 
examples/custom_metrics_and_callbacks.py). "callbacks": { - "on_episode_start": None, # arg: {"env": .., "episode": ...} - "on_episode_step": None, # arg: {"env": .., "episode": ...} - "on_episode_end": None, # arg: {"env": .., "episode": ...} - "on_sample_end": None, # arg: {"samples": .., "evaluator": ...} - "on_train_result": None, # arg: {"trainer": ..., "result": ...} + "on_episode_start": None, # arg: {"env": .., "episode": ...} + "on_episode_step": None, # arg: {"env": .., "episode": ...} + "on_episode_end": None, # arg: {"env": .., "episode": ...} + "on_sample_end": None, # arg: {"samples": .., "evaluator": ...} + "on_train_result": None, # arg: {"trainer": ..., "result": ...} "on_postprocess_traj": None, # arg: {"batch": ..., "episode": ...} }, # Whether to attempt to continue training if a worker crashes. @@ -106,7 +109,10 @@ # and to disable exploration by computing deterministic actions # TODO(kismuz): implement determ. actions and include relevant keys hints "evaluation_config": { - "beholder": False + "beholder": False, + "should_log_histograms": False, + "to_tf_values": to_tf_values, + "debug_learner_session_port": None, }, # === Resources === @@ -220,15 +226,17 @@ # === Multiagent === "multiagent": { - # Map from policy ids to tuples of (policy_graph_cls, obs_space, + # Map from policy ids to tuples of (policy_cls, obs_space, # act_space, config). See policy_evaluator.py for more info. - "policy_graphs": {}, + "policies": {}, # Function mapping agent ids to policy ids. "policy_mapping_fn": None, # Optional whitelist of policies to train, or None for all policies. "policies_to_train": None, }, } + + # __sphinx_doc_end__ # yapf: enable @@ -414,8 +422,13 @@ def _setup(self, config): if self.config.get("log_level"): logging.getLogger("ray.rllib").setLevel(self.config["log_level"]) - # TODO(ekl) setting the graph is unnecessary for PyTorch agents - with tf.Graph().as_default(): + def get_scope(): + if tf: + return tf.Graph().as_default() + else: + return open("/dev/null") # fake a no-op scope + + with get_scope(): self._init(self.config, self.env_creator) # Evaluation related @@ -430,9 +443,7 @@ def _setup(self, config): "using evaluation_config: {}".format(extra_config)) # Make local evaluation evaluators self.evaluation_ev = self.make_local_evaluator( - self.env_creator, - self._policy_graph, - extra_config=extra_config) + self.env_creator, self._policy, extra_config=extra_config) self.evaluation_metrics = self._evaluate() @override(Trainable) @@ -573,10 +584,10 @@ def _default_config(self): @PublicAPI def get_policy(self, policy_id=DEFAULT_POLICY_ID): - """Return policy graph for the specified id, or None. + """Return policy for the specified id, or None. Arguments: - policy_id (str): id of policy graph to return. + policy_id (str): id of policy to return. """ return self.local_evaluator.get_policy(policy_id) @@ -601,28 +612,25 @@ def set_weights(self, weights): self.local_evaluator.set_weights(weights) @DeveloperAPI - def make_local_evaluator(self, - env_creator, - policy_graph, - extra_config=None): + def make_local_evaluator(self, env_creator, policy, extra_config=None): """Convenience method to return configured local evaluator.""" return self._make_evaluator( PolicyEvaluator, env_creator, - policy_graph, + policy, 0, merge_dicts( # important: allow local tf to use more CPUs for optimization merge_dicts( self.config, { "tf_session_args": self. 
- config["local_evaluator_tf_session_args"] + config["local_evaluator_tf_session_args"] }), extra_config or {})) @DeveloperAPI - def make_remote_evaluators(self, env_creator, policy_graph, count): + def make_remote_evaluators(self, env_creator, policy, count): """Convenience method to return a number of remote evaluators.""" remote_args = { @@ -634,8 +642,8 @@ def make_remote_evaluators(self, env_creator, policy_graph, count): cls = PolicyEvaluator.as_remote(**remote_args).remote return [ - self._make_evaluator(cls, env_creator, policy_graph, i + 1, - self.config) for i in range(count) + self._make_evaluator(cls, env_creator, policy, i + 1, self.config) + for i in range(count) ] @DeveloperAPI @@ -695,6 +703,13 @@ def resource_help(cls, config): @staticmethod def _validate_config(config): + if "policy_graphs" in config["multiagent"]: + logger.warning( + "The `policy_graphs` config has been renamed to `policies`.") + # Backwards compatibility + config["multiagent"]["policies"] = config["multiagent"][ + "policy_graphs"] + del config["multiagent"]["policy_graphs"] if "gpu" in config: raise ValueError( "The `gpu` config is deprecated, please use `num_gpus=0|1` " @@ -755,8 +770,7 @@ def _has_policy_optimizer(self): return hasattr(self, "optimizer") and isinstance( self.optimizer, PolicyOptimizer) - def _make_evaluator(self, cls, env_creator, policy_graph, worker_index, - config): + def _make_evaluator(self, cls, env_creator, policy, worker_index, config): def session_creator(): logger.debug("Creating TF session {}".format( config["tf_session_args"])) @@ -798,18 +812,18 @@ def session_creator(): else: input_evaluation = config["input_evaluation"] - # Fill in the default policy graph if 'None' is specified in multiagent - if self.config["multiagent"]["policy_graphs"]: - tmp = self.config["multiagent"]["policy_graphs"] + # Fill in the default policy if 'None' is specified in multiagent + if self.config["multiagent"]["policies"]: + tmp = self.config["multiagent"]["policies"] _validate_multiagent_config(tmp, allow_none_graph=True) for k, v in tmp.items(): if v[0] is None: - tmp[k] = (policy_graph, v[1], v[2], v[3]) - policy_graph = tmp + tmp[k] = (policy, v[1], v[2], v[3]) + policy = tmp return cls( env_creator, - policy_graph, + policy, policy_mapping_fn=self.config["multiagent"]["policy_mapping_fn"], policies_to_train=self.config["multiagent"]["policies_to_train"], tf_session_creator=(session_creator diff --git a/python/ray/rllib/agents/trainer_template.py b/python/ray/rllib/agents/trainer_template.py new file mode 100644 index 000000000000..aae8e35f64f8 --- /dev/null +++ b/python/ray/rllib/agents/trainer_template.py @@ -0,0 +1,96 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from ray.rllib.agents.trainer import Trainer, COMMON_CONFIG +from ray.rllib.optimizers import SyncSamplesOptimizer +from ray.rllib.utils.annotations import override, DeveloperAPI + + +@DeveloperAPI +def build_trainer(name, + default_policy, + default_config=None, + make_policy_optimizer=None, + validate_config=None, + get_policy_class=None, + before_train_step=None, + after_optimizer_step=None, + after_train_result=None): + """Helper function for defining a custom trainer. 
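For orientation, a usage sketch of a trainer produced by this helper (PPOTrainer is itself defined with build_trainer earlier in this diff; the env name and config values are illustrative):

    import ray
    from ray.rllib.agents.ppo import PPOTrainer

    ray.init()
    trainer = PPOTrainer(env="CartPole-v0", config={"num_workers": 1})
    for _ in range(2):
        print(trainer.train())  # each call returns a result dict for one iteration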
+ + Arguments: + name (str): name of the trainer (e.g., "PPO") + default_policy (cls): the default Policy class to use + default_config (dict): the default config dict of the algorithm, + otherwises uses the Trainer default config + make_policy_optimizer (func): optional function that returns a + PolicyOptimizer instance given + (local_evaluator, remote_evaluators, config) + validate_config (func): optional callback that checks a given config + for correctness. It may mutate the config as needed. + get_policy_class (func): optional callback that takes a config and + returns the policy class to override the default with + before_train_step (func): optional callback to run before each train() + call. It takes the trainer instance as an argument. + after_optimizer_step (func): optional callback to run after each + step() call to the policy optimizer. It takes the trainer instance + and the policy gradient fetches as arguments. + after_train_result (func): optional callback to run at the end of each + train() call. It takes the trainer instance and result dict as + arguments, and may mutate the result dict as needed. + + Returns: + a Trainer instance that uses the specified args. + """ + + if not name.endswith("Trainer"): + raise ValueError("Algorithm name should have *Trainer suffix", name) + + class trainer_cls(Trainer): + _name = name + _default_config = default_config or COMMON_CONFIG + _policy = default_policy + + def _init(self, config, env_creator): + if validate_config: + validate_config(config) + if get_policy_class is None: + policy = default_policy + else: + policy = get_policy_class(config) + self.local_evaluator = self.make_local_evaluator( + env_creator, policy) + self.remote_evaluators = self.make_remote_evaluators( + env_creator, policy, config["num_workers"]) + if make_policy_optimizer: + self.optimizer = make_policy_optimizer( + self.local_evaluator, self.remote_evaluators, config) + else: + optimizer_config = dict( + config["optimizer"], + **{"train_batch_size": config["train_batch_size"]}) + self.optimizer = SyncSamplesOptimizer(self.local_evaluator, + self.remote_evaluators, + **optimizer_config) + + @override(Trainer) + def _train(self): + if before_train_step: + before_train_step(self) + prev_steps = self.optimizer.num_steps_sampled + fetches = self.optimizer.step() + if after_optimizer_step: + after_optimizer_step(self, fetches) + res = self.collect_metrics() + res.update( + timesteps_this_iter=self.optimizer.num_steps_sampled - + prev_steps, + info=res.get("info", {})) + if after_train_result: + after_train_result(self, res) + return res + + trainer_cls.__name__ = name + trainer_cls.__qualname__ = name + return trainer_cls diff --git a/python/ray/rllib/evaluation/episode.py b/python/ray/rllib/evaluation/episode.py index b7afa222b149..8d7641b9c313 100644 --- a/python/ray/rllib/evaluation/episode.py +++ b/python/ray/rllib/evaluation/episode.py @@ -27,7 +27,7 @@ class MultiAgentEpisode(object): user_data (dict): Dict that you can use for temporary storage. Use case 1: Model-based rollouts in multi-agent: - A custom compute_actions() function in a policy graph can inspect the + A custom compute_actions() function in a policy can inspect the current episode state and perform a number of rollouts based on the policies and state of other agents in the environment. @@ -80,7 +80,7 @@ def soft_reset(self): @DeveloperAPI def policy_for(self, agent_id=_DUMMY_AGENT_ID): - """Returns the policy graph for the specified agent. + """Returns the policy for the specified agent. 
If the agent is new, the policy mapping fn will be called to bind the agent to a policy for the duration of the episode. diff --git a/python/ray/rllib/evaluation/interface.py b/python/ray/rllib/evaluation/interface.py index eb705a99b530..6bc626da1175 100644 --- a/python/ray/rllib/evaluation/interface.py +++ b/python/ray/rllib/evaluation/interface.py @@ -62,7 +62,7 @@ def compute_gradients(self, samples): Returns: (grads, info): A list of gradients that can be applied on a compatible evaluator. In the multi-agent case, returns a dict - of gradients keyed by policy graph ids. An info dictionary of + of gradients keyed by policy ids. An info dictionary of extra metadata is also returned. Examples: diff --git a/python/ray/rllib/evaluation/metrics.py b/python/ray/rllib/evaluation/metrics.py index fe43257226cf..d8b3122fed4b 100644 --- a/python/ray/rllib/evaluation/metrics.py +++ b/python/ray/rllib/evaluation/metrics.py @@ -7,21 +7,18 @@ import collections import ray -from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID +from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID from ray.rllib.offline.off_policy_estimator import OffPolicyEstimate +from ray.rllib.policy.policy import LEARNER_STATS_KEY from ray.rllib.utils.annotations import DeveloperAPI from ray.rllib.utils.memory import ray_get_and_free logger = logging.getLogger(__name__) -# By convention, metrics from optimizing the loss can be reported in the -# `grad_info` dict returned by learn_on_batch() / compute_grads() via this key. -LEARNER_STATS_KEY = "learner_stats" - @DeveloperAPI def get_learner_stats(grad_info): - """Return optimization stats reported from the policy graph. + """Return optimization stats reported from the policy. Example: >>> grad_info = evaluator.learn_on_batch(samples) @@ -59,18 +56,23 @@ def collect_episodes(local_evaluator=None, timeout_seconds=180): """Gathers new episodes metrics tuples from the given evaluators.""" - pending = [ - a.apply.remote(lambda ev: ev.get_metrics()) for a in remote_evaluators - ] - collected, _ = ray.wait( - pending, num_returns=len(pending), timeout=timeout_seconds * 1.0) - num_metric_batches_dropped = len(pending) - len(collected) - if pending and len(collected) == 0: - raise ValueError( - "Timed out waiting for metrics from workers. You can configure " - "this timeout with `collect_metrics_timeout`.") - - metric_lists = ray_get_and_free(collected) + if remote_evaluators: + pending = [ + a.apply.remote(lambda ev: ev.get_metrics()) + for a in remote_evaluators + ] + collected, _ = ray.wait( + pending, num_returns=len(pending), timeout=timeout_seconds * 1.0) + num_metric_batches_dropped = len(pending) - len(collected) + if pending and len(collected) == 0: + raise ValueError( + "Timed out waiting for metrics from workers. 
You can " + "configure this timeout with `collect_metrics_timeout`.") + metric_lists = ray_get_and_free(collected) + else: + metric_lists = [] + num_metric_batches_dropped = 0 + if local_evaluator: metric_lists.append(local_evaluator.get_metrics()) episodes = [] diff --git a/python/ray/rllib/evaluation/policy_evaluator.py b/python/ray/rllib/evaluation/policy_evaluator.py index f8ebd54ca53c..ea13849d27e4 100644 --- a/python/ray/rllib/evaluation/policy_evaluator.py +++ b/python/ray/rllib/evaluation/policy_evaluator.py @@ -5,7 +5,6 @@ import gym import logging import pickle -import tensorflow as tf from tensorboard.plugins.beholder import Beholder import ray @@ -17,11 +16,10 @@ from ray.rllib.env.external_multi_agent_env import ExternalMultiAgentEnv from ray.rllib.env.vector_env import VectorEnv from ray.rllib.evaluation.interface import EvaluatorInterface -from ray.rllib.evaluation.sample_batch import MultiAgentBatch, \ - DEFAULT_POLICY_ID from ray.rllib.evaluation.sampler import AsyncSampler, SyncSampler -from ray.rllib.evaluation.policy_graph import PolicyGraph -from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph +from ray.rllib.policy.sample_batch import MultiAgentBatch, DEFAULT_POLICY_ID +from ray.rllib.policy.policy import Policy +from ray.rllib.policy.tf_policy import TFPolicy from ray.rllib.offline import NoopOutput, IOContext, OutputWriter, InputReader from ray.rllib.offline.is_estimator import ImportanceSamplingEstimator from ray.rllib.offline.wis_estimator import WeightedImportanceSamplingEstimator @@ -33,7 +31,9 @@ summarize, enable_periodic_logging from ray.rllib.utils.filter import get_filter from ray.rllib.utils.tf_run_builder import TFRunBuilder +from ray.rllib.utils import try_import_tf +tf = try_import_tf() logger = logging.getLogger(__name__) # Handle to the current evaluator, which will be set to the most recently @@ -52,9 +52,9 @@ def get_global_evaluator(): @DeveloperAPI class PolicyEvaluator(EvaluatorInterface): - """Common ``PolicyEvaluator`` implementation that wraps a ``PolicyGraph``. + """Common ``PolicyEvaluator`` implementation that wraps a ``Policy``. - This class wraps a policy graph instance and an environment class to + This class wraps a policy instance and an environment class to collect experiences from the environment. You can create many replicas of this class as Ray actors to scale RL training. @@ -65,7 +65,7 @@ class PolicyEvaluator(EvaluatorInterface): >>> # Create a policy evaluator and using it to collect experiences. >>> evaluator = PolicyEvaluator( ... env_creator=lambda _: gym.make("CartPole-v0"), - ... policy_graph=PGPolicyGraph) + ... policy=PGTFPolicy) >>> print(evaluator.sample()) SampleBatch({ "obs": [[...]], "actions": [[...]], "rewards": [[...]], @@ -76,7 +76,7 @@ class PolicyEvaluator(EvaluatorInterface): ... evaluator_cls=PolicyEvaluator, ... evaluator_args={ ... "env_creator": lambda _: gym.make("CartPole-v0"), - ... "policy_graph": PGPolicyGraph, + ... "policy": PGTFPolicy, ... }, ... num_workers=10) >>> for _ in range(10): optimizer.step() @@ -84,15 +84,15 @@ class PolicyEvaluator(EvaluatorInterface): >>> # Creating a multi-agent policy evaluator >>> evaluator = PolicyEvaluator( ... env_creator=lambda _: MultiAgentTrafficGrid(num_cars=25), - ... policy_graphs={ + ... policies={ ... # Use an ensemble of two policies for car agents ... "car_policy1": - ... (PGPolicyGraph, Box(...), Discrete(...), {"gamma": 0.99}), + ... (PGTFPolicy, Box(...), Discrete(...), {"gamma": 0.99}), ... "car_policy2": - ... 
(PGPolicyGraph, Box(...), Discrete(...), {"gamma": 0.95}), + ... (PGTFPolicy, Box(...), Discrete(...), {"gamma": 0.95}), ... # Use a single shared policy for all traffic lights ... "traffic_light_policy": - ... (PGPolicyGraph, Box(...), Discrete(...), {}), + ... (PGTFPolicy, Box(...), Discrete(...), {}), ... }, ... policy_mapping_fn=lambda agent_id: ... random.choice(["car_policy1", "car_policy2"]) @@ -113,7 +113,7 @@ def as_remote(cls, num_cpus=None, num_gpus=None, resources=None): @DeveloperAPI def __init__(self, env_creator, - policy_graph, + policy, policy_mapping_fn=None, policies_to_train=None, tf_session_creator=None, @@ -147,9 +147,9 @@ def __init__(self, Arguments: env_creator (func): Function that returns a gym.Env given an EnvContext wrapped configuration. - policy_graph (class|dict): Either a class implementing - PolicyGraph, or a dictionary of policy id strings to - (PolicyGraph, obs_space, action_space, config) tuples. If a + policy (class|dict): Either a class implementing + Policy, or a dictionary of policy id strings to + (Policy, obs_space, action_space, config) tuples. If a dict is specified, then we are in multi-agent mode and a policy_mapping_fn should also be set. policy_mapping_fn (func): A function that maps agent ids to @@ -159,7 +159,7 @@ def __init__(self, policies_to_train (list): Optional whitelist of policies to train, or None for all policies. tf_session_creator (func): A function that returns a TF session. - This is optional and only useful with TFPolicyGraph. + This is optional and only useful with TFPolicy. batch_steps (int): The target number of env transitions to include in each sample batch returned from this evaluator. batch_mode (str): One of the following batch modes: @@ -196,7 +196,7 @@ def __init__(self, model_config (dict): Config to use when creating the policy model. policy_config (dict): Config to pass to the policy. In the multi-agent case, this config will be merged with the - per-policy configs specified by `policy_graph`. + per-policy configs specified by `policy`. worker_index (int): For remote evaluators, this should be set to a non-zero and unique value. This index is passed to created envs through EnvContext so that envs can be configured per worker. @@ -302,7 +302,7 @@ def make_env(vector_index): vector_index=vector_index, remote=remote_worker_envs))) self.tf_sess = None - policy_dict = _validate_and_canonicalize(policy_graph, self.env) + policy_dict = _validate_and_canonicalize(policy, self.env) self.policies_to_train = policies_to_train or list(policy_dict.keys()) if _has_tensorflow_graph(policy_dict): if (ray.is_initialized() @@ -331,7 +331,7 @@ def make_env(vector_index): or isinstance(self.env, ExternalMultiAgentEnv)) or isinstance(self.env, BaseEnv)): raise ValueError( - "Have multiple policy graphs {}, but the env ".format( + "Have multiple policies {}, but the env ".format( self.policy_map) + "{} is not a subclass of BaseEnv, MultiAgentEnv or " "ExternalMultiAgentEnv?".format(self.env)) @@ -615,17 +615,17 @@ def foreach_env(self, func): @DeveloperAPI def get_policy(self, policy_id=DEFAULT_POLICY_ID): - """Return policy graph for the specified id, or None. + """Return policy for the specified id, or None. Arguments: - policy_id (str): id of policy graph to return. + policy_id (str): id of policy to return. 
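A small usage sketch tying this accessor to the single-agent construction shown in the class docstring above; the PGTFPolicy import path is an assumption, and any Policy subclass works:

    import gym
    from ray.rllib.agents.pg.pg_policy import PGTFPolicy  # path assumed
    from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator

    evaluator = PolicyEvaluator(
        env_creator=lambda _: gym.make("CartPole-v0"), policy=PGTFPolicy)
    default_policy = evaluator.get_policy()  # DEFAULT_POLICY_ID by default
    weights = default_policy.get_weights()   # e.g., for syncing to remote evaluators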
""" return self.policy_map.get(policy_id) @DeveloperAPI def for_policy(self, func, policy_id=DEFAULT_POLICY_ID): - """Apply the given function to the specified policy graph.""" + """Apply the given function to the specified policy.""" return func(self.policy_map[policy_id]) @@ -715,7 +715,7 @@ def _build_policy_map(self, policy_dict, policy_config): preprocessors = {} for name, (cls, obs_space, act_space, conf) in sorted(policy_dict.items()): - logger.debug("Creating policy graph for {}".format(name)) + logger.debug("Creating policy for {}".format(name)) merged_conf = merge_dicts(policy_config, conf) if self.preprocessing_enabled: preprocessor = ModelCatalog.get_preprocessor_for_space( @@ -727,10 +727,13 @@ def _build_policy_map(self, policy_dict, policy_config): if isinstance(obs_space, gym.spaces.Dict) or \ isinstance(obs_space, gym.spaces.Tuple): raise ValueError( - "Found raw Tuple|Dict space as input to policy graph. " + "Found raw Tuple|Dict space as input to policy. " "Please preprocess these observations with a " "Tuple|DictFlatteningPreprocessor.") - with tf.variable_scope(name): + if tf: + with tf.variable_scope(name): + policy_map[name] = cls(obs_space, act_space, merged_conf) + else: policy_map[name] = cls(obs_space, act_space, merged_conf) if self.worker_index == 0: logger.info("Built policy map: {}".format(policy_map)) @@ -742,12 +745,12 @@ def __del__(self): self.sampler.shutdown = True -def _validate_and_canonicalize(policy_graph, env): - if isinstance(policy_graph, dict): - _validate_multiagent_config(policy_graph) - return policy_graph - elif not issubclass(policy_graph, PolicyGraph): - raise ValueError("policy_graph must be a rllib.PolicyGraph class") +def _validate_and_canonicalize(policy, env): + if isinstance(policy, dict): + _validate_multiagent_config(policy) + return policy + elif not issubclass(policy, Policy): + raise ValueError("policy must be a rllib.Policy class") else: if (isinstance(env, MultiAgentEnv) and not hasattr(env, "observation_space")): @@ -755,38 +758,35 @@ def _validate_and_canonicalize(policy_graph, env): "MultiAgentEnv must have observation_space defined if run " "in a single-agent configuration.") return { - DEFAULT_POLICY_ID: (policy_graph, env.observation_space, + DEFAULT_POLICY_ID: (policy, env.observation_space, env.action_space, {}) } -def _validate_multiagent_config(policy_graph, allow_none_graph=False): - for k, v in policy_graph.items(): +def _validate_multiagent_config(policy, allow_none_graph=False): + for k, v in policy.items(): if not isinstance(k, str): - raise ValueError("policy_graph keys must be strs, got {}".format( + raise ValueError("policy keys must be strs, got {}".format( type(k))) if not isinstance(v, tuple) or len(v) != 4: raise ValueError( - "policy_graph values must be tuples of " + "policy values must be tuples of " "(cls, obs_space, action_space, config), got {}".format(v)) if allow_none_graph and v[0] is None: pass - elif not issubclass(v[0], PolicyGraph): - raise ValueError( - "policy_graph tuple value 0 must be a rllib.PolicyGraph " - "class or None, got {}".format(v[0])) + elif not issubclass(v[0], Policy): + raise ValueError("policy tuple value 0 must be a rllib.Policy " + "class or None, got {}".format(v[0])) if not isinstance(v[1], gym.Space): raise ValueError( - "policy_graph tuple value 1 (observation_space) must be a " + "policy tuple value 1 (observation_space) must be a " "gym.Space, got {}".format(type(v[1]))) if not isinstance(v[2], gym.Space): - raise ValueError( - "policy_graph tuple value 2 
(action_space) must be a " - "gym.Space, got {}".format(type(v[2]))) + raise ValueError("policy tuple value 2 (action_space) must be a " + "gym.Space, got {}".format(type(v[2]))) if not isinstance(v[3], dict): - raise ValueError( - "policy_graph tuple value 3 (config) must be a dict, " - "got {}".format(type(v[3]))) + raise ValueError("policy tuple value 3 (config) must be a dict, " + "got {}".format(type(v[3]))) def _validate_env(env): @@ -809,6 +809,6 @@ def _monitor(env, path): def _has_tensorflow_graph(policy_dict): for policy, _, _, _ in policy_dict.values(): - if issubclass(policy, TFPolicyGraph): + if issubclass(policy, TFPolicy): return True return False diff --git a/python/ray/rllib/evaluation/policy_graph.py b/python/ray/rllib/evaluation/policy_graph.py index a577550975f9..5d0fdf2a4e57 100644 --- a/python/ray/rllib/evaluation/policy_graph.py +++ b/python/ray/rllib/evaluation/policy_graph.py @@ -2,286 +2,7 @@ from __future__ import division from __future__ import print_function -import numpy as np -import gym +from ray.rllib.policy.policy import Policy +from ray.rllib.utils import renamed_class -from ray.rllib.utils.annotations import DeveloperAPI - - -@DeveloperAPI -class PolicyGraph(object): - """An agent policy and loss, i.e., a TFPolicyGraph or other subclass. - - This object defines how to act in the environment, and also losses used to - improve the policy based on its experiences. Note that both policy and - loss are defined together for convenience, though the policy itself is - logically separate. - - All policies can directly extend PolicyGraph, however TensorFlow users may - find TFPolicyGraph simpler to implement. TFPolicyGraph also enables RLlib - to apply TensorFlow-specific optimizations such as fusing multiple policy - graphs and multi-GPU support. - - Attributes: - observation_space (gym.Space): Observation space of the policy. - action_space (gym.Space): Action space of the policy. - """ - - @DeveloperAPI - def __init__(self, observation_space, action_space, config): - """Initialize the graph. - - This is the standard constructor for policy graphs. The policy graph - class you pass into PolicyEvaluator will be constructed with - these arguments. - - Args: - observation_space (gym.Space): Observation space of the policy. - action_space (gym.Space): Action space of the policy. - config (dict): Policy-specific configuration data. - """ - - self.observation_space = observation_space - self.action_space = action_space - - @DeveloperAPI - def compute_actions(self, - obs_batch, - state_batches, - prev_action_batch=None, - prev_reward_batch=None, - info_batch=None, - episodes=None, - **kwargs): - """Compute actions for the current policy. - - Arguments: - obs_batch (np.ndarray): batch of observations - state_batches (list): list of RNN state input batches, if any - prev_action_batch (np.ndarray): batch of previous action values - prev_reward_batch (np.ndarray): batch of previous rewards - info_batch (info): batch of info objects - episodes (list): MultiAgentEpisode for each obs in obs_batch. - This provides access to all of the internal episode state, - which may be useful for model-based or multiagent algorithms. - kwargs: forward compatibility placeholder - - Returns: - actions (np.ndarray): batch of output actions, with shape like - [BATCH_SIZE, ACTION_SHAPE]. - state_outs (list): list of RNN state output batches, if any, with - shape like [STATE_SIZE, BATCH_SIZE]. 
- info (dict): dictionary of extra feature batches, if any, with - shape like {"f1": [BATCH_SIZE, ...], "f2": [BATCH_SIZE, ...]}. - """ - raise NotImplementedError - - @DeveloperAPI - def compute_single_action(self, - obs, - state, - prev_action=None, - prev_reward=None, - info=None, - episode=None, - clip_actions=False, - **kwargs): - """Unbatched version of compute_actions. - - Arguments: - obs (obj): single observation - state_batches (list): list of RNN state inputs, if any - prev_action (obj): previous action value, if any - prev_reward (int): previous reward, if any - info (dict): info object, if any - episode (MultiAgentEpisode): this provides access to all of the - internal episode state, which may be useful for model-based or - multi-agent algorithms. - clip_actions (bool): should the action be clipped - kwargs: forward compatibility placeholder - - Returns: - actions (obj): single action - state_outs (list): list of RNN state outputs, if any - info (dict): dictionary of extra features, if any - """ - - prev_action_batch = None - prev_reward_batch = None - info_batch = None - episodes = None - if prev_action is not None: - prev_action_batch = [prev_action] - if prev_reward is not None: - prev_reward_batch = [prev_reward] - if info is not None: - info_batch = [info] - if episode is not None: - episodes = [episode] - [action], state_out, info = self.compute_actions( - [obs], [[s] for s in state], - prev_action_batch=prev_action_batch, - prev_reward_batch=prev_reward_batch, - info_batch=info_batch, - episodes=episodes) - if clip_actions: - action = clip_action(action, self.action_space) - return action, [s[0] for s in state_out], \ - {k: v[0] for k, v in info.items()} - - @DeveloperAPI - def postprocess_trajectory(self, - sample_batch, - other_agent_batches=None, - episode=None): - """Implements algorithm-specific trajectory postprocessing. - - This will be called on each trajectory fragment computed during policy - evaluation. Each fragment is guaranteed to be only from one episode. - - Arguments: - sample_batch (SampleBatch): batch of experiences for the policy, - which will contain at most one episode trajectory. - other_agent_batches (dict): In a multi-agent env, this contains a - mapping of agent ids to (policy_graph, agent_batch) tuples - containing the policy graph and experiences of the other agent. - episode (MultiAgentEpisode): this provides access to all of the - internal episode state, which may be useful for model-based or - multi-agent algorithms. - - Returns: - SampleBatch: postprocessed sample batch. - """ - return sample_batch - - @DeveloperAPI - def learn_on_batch(self, samples): - """Fused compute gradients and apply gradients call. - - Either this or the combination of compute/apply grads must be - implemented by subclasses. - - Returns: - grad_info: dictionary of extra metadata from compute_gradients(). - - Examples: - >>> batch = ev.sample() - >>> ev.learn_on_batch(samples) - """ - - grads, grad_info = self.compute_gradients(samples) - self.apply_gradients(grads) - return grad_info - - @DeveloperAPI - def compute_gradients(self, postprocessed_batch): - """Computes gradients against a batch of experiences. - - Either this or learn_on_batch() must be implemented by subclasses. - - Returns: - grads (list): List of gradient output values - info (dict): Extra policy-specific values - """ - raise NotImplementedError - - @DeveloperAPI - def apply_gradients(self, gradients): - """Applies previously computed gradients. 
- - Either this or learn_on_batch() must be implemented by subclasses. - """ - raise NotImplementedError - - @DeveloperAPI - def get_weights(self): - """Returns model weights. - - Returns: - weights (obj): Serializable copy or view of model weights - """ - raise NotImplementedError - - @DeveloperAPI - def set_weights(self, weights): - """Sets model weights. - - Arguments: - weights (obj): Serializable copy or view of model weights - """ - raise NotImplementedError - - @DeveloperAPI - def get_initial_state(self): - """Returns initial RNN state for the current policy.""" - return [] - - @DeveloperAPI - def get_state(self): - """Saves all local state. - - Returns: - state (obj): Serialized local state. - """ - return self.get_weights() - - @DeveloperAPI - def set_state(self, state): - """Restores all local state. - - Arguments: - state (obj): Serialized local state. - """ - self.set_weights(state) - - @DeveloperAPI - def on_global_var_update(self, global_vars): - """Called on an update to global vars. - - Arguments: - global_vars (dict): Global variables broadcast from the driver. - """ - pass - - @DeveloperAPI - def export_model(self, export_dir): - """Export PolicyGraph to local directory for serving. - - Arguments: - export_dir (str): Local writable directory. - """ - raise NotImplementedError - - @DeveloperAPI - def export_checkpoint(self, export_dir): - """Export PolicyGraph checkpoint to local directory. - - Argument: - export_dir (str): Local writable directory. - """ - raise NotImplementedError - - -def clip_action(action, space): - """Called to clip actions to the specified range of this policy. - - Arguments: - action: Single action. - space: Action space the actions should be present in. - - Returns: - Clipped batch of actions. - """ - - if isinstance(space, gym.spaces.Box): - return np.clip(action, space.low, space.high) - elif isinstance(space, gym.spaces.Tuple): - if type(action) not in (tuple, list): - raise ValueError("Expected tuple space for actions {}: {}".format( - action, space)) - out = [] - for a, s in zip(action, space.spaces): - out.append(clip_action(a, s)) - return out - else: - return action +PolicyGraph = renamed_class(Policy, old_name="PolicyGraph") diff --git a/python/ray/rllib/evaluation/postprocessing.py b/python/ray/rllib/evaluation/postprocessing.py index aa2835f87e04..f236df6ed763 100644 --- a/python/ray/rllib/evaluation/postprocessing.py +++ b/python/ray/rllib/evaluation/postprocessing.py @@ -4,7 +4,7 @@ import numpy as np import scipy.signal -from ray.rllib.evaluation.sample_batch import SampleBatch +from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.utils.annotations import DeveloperAPI diff --git a/python/ray/rllib/evaluation/sample_batch.py b/python/ray/rllib/evaluation/sample_batch.py index c80f22bdbd1a..2c0f119a94b2 100644 --- a/python/ray/rllib/evaluation/sample_batch.py +++ b/python/ray/rllib/evaluation/sample_batch.py @@ -2,295 +2,10 @@ from __future__ import division from __future__ import print_function -import six -import collections -import numpy as np +from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch +from ray.rllib.utils import renamed_class -from ray.rllib.utils.annotations import PublicAPI, DeveloperAPI -from ray.rllib.utils.compression import pack, unpack, is_compressed -from ray.rllib.utils.memory import concat_aligned - -# Defaults policy id for single agent environments -DEFAULT_POLICY_ID = "default_policy" - - -@PublicAPI -class MultiAgentBatch(object): - """A batch of experiences from multiple 
policies in the environment. - - Attributes: - policy_batches (dict): Mapping from policy id to a normal SampleBatch - of experiences. Note that these batches may be of different length. - count (int): The number of timesteps in the environment this batch - contains. This will be less than the number of transitions this - batch contains across all policies in total. - """ - - @PublicAPI - def __init__(self, policy_batches, count): - self.policy_batches = policy_batches - self.count = count - - @staticmethod - @PublicAPI - def wrap_as_needed(batches, count): - if len(batches) == 1 and DEFAULT_POLICY_ID in batches: - return batches[DEFAULT_POLICY_ID] - return MultiAgentBatch(batches, count) - - @staticmethod - @PublicAPI - def concat_samples(samples): - policy_batches = collections.defaultdict(list) - total_count = 0 - for s in samples: - assert isinstance(s, MultiAgentBatch) - for policy_id, batch in s.policy_batches.items(): - policy_batches[policy_id].append(batch) - total_count += s.count - out = {} - for policy_id, batches in policy_batches.items(): - out[policy_id] = SampleBatch.concat_samples(batches) - return MultiAgentBatch(out, total_count) - - @PublicAPI - def copy(self): - return MultiAgentBatch( - {k: v.copy() - for (k, v) in self.policy_batches.items()}, self.count) - - @PublicAPI - def total(self): - ct = 0 - for batch in self.policy_batches.values(): - ct += batch.count - return ct - - @DeveloperAPI - def compress(self, bulk=False, columns=frozenset(["obs", "new_obs"])): - for batch in self.policy_batches.values(): - batch.compress(bulk=bulk, columns=columns) - - @DeveloperAPI - def decompress_if_needed(self, columns=frozenset(["obs", "new_obs"])): - for batch in self.policy_batches.values(): - batch.decompress_if_needed(columns) - - def __str__(self): - return "MultiAgentBatch({}, count={})".format( - str(self.policy_batches), self.count) - - def __repr__(self): - return "MultiAgentBatch({}, count={})".format( - str(self.policy_batches), self.count) - - -@PublicAPI -class SampleBatch(object): - """Wrapper around a dictionary with string keys and array-like values. - - For example, {"obs": [1, 2, 3], "reward": [0, -1, 1]} is a batch of three - samples, each with an "obs" and "reward" attribute. - """ - - # Outputs from interacting with the environment - CUR_OBS = "obs" - NEXT_OBS = "new_obs" - ACTIONS = "actions" - REWARDS = "rewards" - PREV_ACTIONS = "prev_actions" - PREV_REWARDS = "prev_rewards" - DONES = "dones" - INFOS = "infos" - - # Uniquely identifies an episode - EPS_ID = "eps_id" - - # Uniquely identifies a sample batch. This is important to distinguish RNN - # sequences from the same episode when multiple sample batches are - # concatenated (fusing sequences across batches can be unsafe). 
- UNROLL_ID = "unroll_id" - - # Uniquely identifies an agent within an episode - AGENT_INDEX = "agent_index" - - # Value function predictions emitted by the behaviour policy - VF_PREDS = "vf_preds" - - @PublicAPI - def __init__(self, *args, **kwargs): - """Constructs a sample batch (same params as dict constructor).""" - - self.data = dict(*args, **kwargs) - lengths = [] - for k, v in self.data.copy().items(): - assert isinstance(k, six.string_types), self - lengths.append(len(v)) - self.data[k] = np.array(v, copy=False) - if not lengths: - raise ValueError("Empty sample batch") - assert len(set(lengths)) == 1, "data columns must be same length" - self.count = lengths[0] - - @staticmethod - @PublicAPI - def concat_samples(samples): - if isinstance(samples[0], MultiAgentBatch): - return MultiAgentBatch.concat_samples(samples) - out = {} - samples = [s for s in samples if s.count > 0] - for k in samples[0].keys(): - out[k] = concat_aligned([s[k] for s in samples]) - return SampleBatch(out) - - @PublicAPI - def concat(self, other): - """Returns a new SampleBatch with each data column concatenated. - - Examples: - >>> b1 = SampleBatch({"a": [1, 2]}) - >>> b2 = SampleBatch({"a": [3, 4, 5]}) - >>> print(b1.concat(b2)) - {"a": [1, 2, 3, 4, 5]} - """ - - assert self.keys() == other.keys(), "must have same columns" - out = {} - for k in self.keys(): - out[k] = concat_aligned([self[k], other[k]]) - return SampleBatch(out) - - @PublicAPI - def copy(self): - return SampleBatch( - {k: np.array(v, copy=True) - for (k, v) in self.data.items()}) - - @PublicAPI - def rows(self): - """Returns an iterator over data rows, i.e. dicts with column values. - - Examples: - >>> batch = SampleBatch({"a": [1, 2, 3], "b": [4, 5, 6]}) - >>> for row in batch.rows(): - print(row) - {"a": 1, "b": 4} - {"a": 2, "b": 5} - {"a": 3, "b": 6} - """ - - for i in range(self.count): - row = {} - for k in self.keys(): - row[k] = self[k][i] - yield row - - @PublicAPI - def columns(self, keys): - """Returns a list of just the specified columns. - - Examples: - >>> batch = SampleBatch({"a": [1], "b": [2], "c": [3]}) - >>> print(batch.columns(["a", "b"])) - [[1], [2]] - """ - - out = [] - for k in keys: - out.append(self[k]) - return out - - @PublicAPI - def shuffle(self): - """Shuffles the rows of this batch in-place.""" - - permutation = np.random.permutation(self.count) - for key, val in self.items(): - self[key] = val[permutation] - - @PublicAPI - def split_by_episode(self): - """Splits this batch's data by `eps_id`. - - Returns: - list of SampleBatch, one per distinct episode. - """ - - slices = [] - cur_eps_id = self.data["eps_id"][0] - offset = 0 - for i in range(self.count): - next_eps_id = self.data["eps_id"][i] - if next_eps_id != cur_eps_id: - slices.append(self.slice(offset, i)) - offset = i - cur_eps_id = next_eps_id - slices.append(self.slice(offset, self.count)) - for s in slices: - slen = len(set(s["eps_id"])) - assert slen == 1, (s, slen) - assert sum(s.count for s in slices) == self.count, (slices, self.count) - return slices - - @PublicAPI - def slice(self, start, end): - """Returns a slice of the row data of this batch. - - Arguments: - start (int): Starting index. - end (int): Ending index. - - Returns: - SampleBatch which has a slice of this batch's data. 
- """ - - return SampleBatch({k: v[start:end] for k, v in self.data.items()}) - - @PublicAPI - def keys(self): - return self.data.keys() - - @PublicAPI - def items(self): - return self.data.items() - - @PublicAPI - def __getitem__(self, key): - return self.data[key] - - @PublicAPI - def __setitem__(self, key, item): - self.data[key] = item - - @DeveloperAPI - def compress(self, bulk=False, columns=frozenset(["obs", "new_obs"])): - for key in columns: - if key in self.data: - if bulk: - self.data[key] = pack(self.data[key]) - else: - self.data[key] = np.array( - [pack(o) for o in self.data[key]]) - - @DeveloperAPI - def decompress_if_needed(self, columns=frozenset(["obs", "new_obs"])): - for key in columns: - if key in self.data: - arr = self.data[key] - if is_compressed(arr): - self.data[key] = unpack(arr) - elif len(arr) > 0 and is_compressed(arr[0]): - self.data[key] = np.array( - [unpack(o) for o in self.data[key]]) - - def __str__(self): - return "SampleBatch({})".format(str(self.data)) - - def __repr__(self): - return "SampleBatch({})".format(str(self.data)) - - def __iter__(self): - return self.data.__iter__() - - def __contains__(self, x): - return x in self.data +SampleBatch = renamed_class( + SampleBatch, old_name="rllib.evaluation.SampleBatch") +MultiAgentBatch = renamed_class( + MultiAgentBatch, old_name="rllib.evaluation.MultiAgentBatch") diff --git a/python/ray/rllib/evaluation/sample_batch_builder.py b/python/ray/rllib/evaluation/sample_batch_builder.py index c6d69d7d97f1..0ead77d52847 100644 --- a/python/ray/rllib/evaluation/sample_batch_builder.py +++ b/python/ray/rllib/evaluation/sample_batch_builder.py @@ -6,7 +6,7 @@ import logging import numpy as np -from ray.rllib.evaluation.sample_batch import SampleBatch, MultiAgentBatch +from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch from ray.rllib.utils.annotations import PublicAPI, DeveloperAPI from ray.rllib.utils.debug import log_once, summarize @@ -79,7 +79,7 @@ def __init__(self, policy_map, clip_rewards, postp_callback): """Initialize a MultiAgentSampleBatchBuilder. Arguments: - policy_map (dict): Maps policy ids to policy graph instances. + policy_map (dict): Maps policy ids to policy instances. clip_rewards (bool): Whether to clip rewards before postprocessing. postp_callback: function to call on each postprocessed batch. """ diff --git a/python/ray/rllib/evaluation/sampler.py b/python/ray/rllib/evaluation/sampler.py index 4bc5e3cc28dc..2773ddd05b6b 100644 --- a/python/ray/rllib/evaluation/sampler.py +++ b/python/ray/rllib/evaluation/sampler.py @@ -12,7 +12,7 @@ from ray.rllib.evaluation.episode import MultiAgentEpisode, _flatten_action from ray.rllib.evaluation.sample_batch_builder import \ MultiAgentSampleBatchBuilder -from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph +from ray.rllib.policy.tf_policy import TFPolicy from ray.rllib.env.base_env import BaseEnv, ASYNC_RESET_RETURN from ray.rllib.env.atari_wrappers import get_wrapper_by_cls, MonitorEnv from ray.rllib.models.action_dist import TupleActions @@ -20,7 +20,7 @@ from ray.rllib.utils.annotations import override from ray.rllib.utils.debug import log_once, summarize from ray.rllib.utils.tf_run_builder import TFRunBuilder -from ray.rllib.evaluation.policy_graph import clip_action +from ray.rllib.policy.policy import clip_action logger = logging.getLogger(__name__) @@ -236,7 +236,7 @@ def _env_runner(base_env, extra_batch_callback, policies, policy_mapping_fn, Args: base_env (BaseEnv): env implementing BaseEnv. 
extra_batch_callback (fn): function to send extra batch data to. - policies (dict): Map of policy ids to PolicyGraph instances. + policies (dict): Map of policy ids to Policy instances. policy_mapping_fn (func): Function that maps agent ids to policy ids. This is called when an agent first enters the environment. The agent is then "bound" to the returned policy for the episode. @@ -528,7 +528,7 @@ def _do_policy_eval(tf_sess, to_eval, policies, active_episodes): rnn_in_cols = _to_column_format([t.rnn_state for t in eval_data]) policy = _get_or_raise(policies, policy_id) if builder and (policy.compute_actions.__code__ is - TFPolicyGraph.compute_actions.__code__): + TFPolicy.compute_actions.__code__): # TODO(ekl): how can we make info batch available to TF code? pending_fetches[policy_id] = policy._build_compute_actions( builder, [t.obs for t in eval_data], diff --git a/python/ray/rllib/evaluation/tf_policy_graph.py b/python/ray/rllib/evaluation/tf_policy_graph.py index b5e38192894d..2c4955a17ff1 100644 --- a/python/ray/rllib/evaluation/tf_policy_graph.py +++ b/python/ray/rllib/evaluation/tf_policy_graph.py @@ -2,506 +2,7 @@ from __future__ import division from __future__ import print_function -import os -import errno -import logging -import tensorflow as tf -import numpy as np +from ray.rllib.policy.tf_policy import TFPolicy +from ray.rllib.utils import renamed_class -import ray -import ray.experimental.tf_utils -from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY -from ray.rllib.evaluation.policy_graph import PolicyGraph -from ray.rllib.evaluation.sample_batch import SampleBatch -from ray.rllib.models.lstm import chop_into_sequences -from ray.rllib.utils.annotations import override, DeveloperAPI -from ray.rllib.utils.debug import log_once, summarize -from ray.rllib.utils.schedules import ConstantSchedule, PiecewiseSchedule, LinearSchedule -from ray.rllib.utils.tf_run_builder import TFRunBuilder - -logger = logging.getLogger(__name__) - - -@DeveloperAPI -class TFPolicyGraph(PolicyGraph): - """An agent policy and loss implemented in TensorFlow. - - Extending this class enables RLlib to perform TensorFlow specific - optimizations on the policy graph, e.g., parallelization across gpus or - fusing multiple graphs together in the multi-agent setting. - - Input tensors are typically shaped like [BATCH_SIZE, ...]. - - Attributes: - observation_space (gym.Space): observation space of the policy. - action_space (gym.Space): action space of the policy. - model (rllib.models.Model): RLlib model used for the policy. - - Examples: - >>> policy = TFPolicyGraphSubclass( - sess, obs_input, action_sampler, loss, loss_inputs) - - >>> print(policy.compute_actions([1, 0, 2])) - (array([0, 1, 1]), [], {}) - - >>> print(policy.postprocess_trajectory(SampleBatch({...}))) - SampleBatch({"action": ..., "advantages": ..., ...}) - """ - - @DeveloperAPI - def __init__(self, - observation_space, - action_space, - sess, - obs_input, - action_sampler, - loss, - loss_inputs, - model=None, - action_prob=None, - state_inputs=None, - state_outputs=None, - prev_action_input=None, - prev_reward_input=None, - seq_lens=None, - max_seq_len=20, - batch_divisibility_req=1, - update_ops=None): - """Initialize the policy graph. - - Arguments: - observation_space (gym.Space): Observation space of the env. - action_space (gym.Space): Action space of the env. - sess (Session): TensorFlow session to use. - obs_input (Tensor): input placeholder for observations, of shape - [BATCH_SIZE, obs...]. 
- action_sampler (Tensor): Tensor for sampling an action, of shape - [BATCH_SIZE, action...] - loss (Tensor): scalar policy loss output tensor. - loss_inputs (list): a (name, placeholder) tuple for each loss - input argument. Each placeholder name must correspond to a - SampleBatch column key returned by postprocess_trajectory(), - and has shape [BATCH_SIZE, data...]. These keys will be read - from postprocessed sample batches and fed into the specified - placeholders during loss computation. - model (rllib.models.Model): used to integrate custom losses and - stats from user-defined RLlib models. - action_prob (Tensor): probability of the sampled action. - state_inputs (list): list of RNN state input Tensors. - state_outputs (list): list of RNN state output Tensors. - prev_action_input (Tensor): placeholder for previous actions - prev_reward_input (Tensor): placeholder for previous rewards - seq_lens (Tensor): placeholder for RNN sequence lengths, of shape - [NUM_SEQUENCES]. Note that NUM_SEQUENCES << BATCH_SIZE. See - models/lstm.py for more information. - max_seq_len (int): max sequence length for LSTM training. - batch_divisibility_req (int): pad all agent experiences batches to - multiples of this value. This only has an effect if not using - a LSTM model. - update_ops (list): override the batchnorm update ops to run when - applying gradients. Otherwise we run all update ops found in - the current variable scope. - """ - - self.observation_space = observation_space - self.action_space = action_space - self.model = model - self._sess = sess - self._obs_input = obs_input - self._prev_action_input = prev_action_input - self._prev_reward_input = prev_reward_input - self._sampler = action_sampler - self._loss_inputs = loss_inputs - self._loss_input_dict = dict(self._loss_inputs) - self._is_training = self._get_is_training_placeholder() - self._action_prob = action_prob - self._state_inputs = state_inputs or [] - self._state_outputs = state_outputs or [] - for i, ph in enumerate(self._state_inputs): - self._loss_input_dict["state_in_{}".format(i)] = ph - self._seq_lens = seq_lens - self._max_seq_len = max_seq_len - self._batch_divisibility_req = batch_divisibility_req - - if self.model: - self._loss = self.model.custom_loss(loss, self._loss_input_dict) - self._stats_fetches = {"model": self.model.custom_stats()} - else: - self._loss = loss - self._stats_fetches = {} - - self._optimizer = self.optimizer() - self._grads_and_vars = [ - (g, v) for (g, v) in self.gradients(self._optimizer, self._loss) - if g is not None - ] - self._grads = [g for (g, v) in self._grads_and_vars] - self._variables = ray.experimental.tf_utils.TensorFlowVariables( - self._loss, self._sess) - - # gather update ops for any batch norm layers - if update_ops: - self._update_ops = update_ops - else: - self._update_ops = tf.get_collection( - tf.GraphKeys.UPDATE_OPS, scope=tf.get_variable_scope().name) - if self._update_ops: - logger.debug("Update ops to run on apply gradient: {}".format( - self._update_ops)) - with tf.control_dependencies(self._update_ops): - self._apply_op = self.build_apply_op(self._optimizer, - self._grads_and_vars) - - if len(self._state_inputs) != len(self._state_outputs): - raise ValueError( - "Number of state input and output tensors must match, got: " - "{} vs {}".format(self._state_inputs, self._state_outputs)) - if len(self.get_initial_state()) != len(self._state_inputs): - raise ValueError( - "Length of initial state must match number of state inputs, " - "got: {} vs 
{}".format(self.get_initial_state(), - self._state_inputs)) - if self._state_inputs and self._seq_lens is None: - raise ValueError( - "seq_lens tensor must be given if state inputs are defined") - - logger.debug("Created {} with loss inputs: {}".format( - self, self._loss_input_dict)) - - @override(PolicyGraph) - def compute_actions(self, - obs_batch, - state_batches=None, - prev_action_batch=None, - prev_reward_batch=None, - info_batch=None, - episodes=None, - **kwargs): - builder = TFRunBuilder(self._sess, "compute_actions") - fetches = self._build_compute_actions(builder, obs_batch, - state_batches, prev_action_batch, - prev_reward_batch) - return builder.get(fetches) - - @override(PolicyGraph) - def compute_gradients(self, postprocessed_batch): - builder = TFRunBuilder(self._sess, "compute_gradients") - fetches = self._build_compute_gradients(builder, postprocessed_batch) - return builder.get(fetches) - - @override(PolicyGraph) - def apply_gradients(self, gradients): - builder = TFRunBuilder(self._sess, "apply_gradients") - fetches = self._build_apply_gradients(builder, gradients) - builder.get(fetches) - - @override(PolicyGraph) - def learn_on_batch(self, postprocessed_batch): - return self._learn_on_batch(self._sess, postprocessed_batch) - - def _learn_on_batch(self, sess, postprocessed_batch): - builder = TFRunBuilder(sess, "learn_on_batch") - fetches = self._build_learn_on_batch(builder, postprocessed_batch) - return builder.get(fetches) - - @override(PolicyGraph) - def get_weights(self): - return self._variables.get_flat() - - @override(PolicyGraph) - def set_weights(self, weights): - return self._variables.set_flat(weights) - - @override(PolicyGraph) - def export_model(self, export_dir): - """Export tensorflow graph to export_dir for serving.""" - with self._sess.graph.as_default(): - builder = tf.saved_model.builder.SavedModelBuilder(export_dir) - signature_def_map = self._build_signature_def() - builder.add_meta_graph_and_variables( - self._sess, [tf.saved_model.tag_constants.SERVING], - signature_def_map=signature_def_map) - builder.save() - - @override(PolicyGraph) - def export_checkpoint(self, export_dir, filename_prefix="model"): - """Export tensorflow checkpoint to export_dir.""" - try: - os.makedirs(export_dir) - except OSError as e: - # ignore error if export dir already exists - if e.errno != errno.EEXIST: - raise - save_path = os.path.join(export_dir, filename_prefix) - with self._sess.graph.as_default(): - saver = tf.train.Saver() - saver.save(self._sess, save_path) - - @DeveloperAPI - def copy(self, existing_inputs): - """Creates a copy of self using existing input placeholders. - - Optional, only required to work with the multi-GPU optimizer.""" - raise NotImplementedError - - @DeveloperAPI - def extra_compute_action_feed_dict(self): - """Extra dict to pass to the compute actions session run.""" - return {} - - @DeveloperAPI - def extra_compute_action_fetches(self): - """Extra values to fetch and return from compute_actions(). - - By default we only return action probability info (if present). - """ - if self._action_prob is not None: - return {"action_prob": self._action_prob} - else: - return {} - - @DeveloperAPI - def extra_compute_grad_feed_dict(self): - """Extra dict to pass to the compute gradients session run.""" - return {} # e.g, kl_coeff - - @DeveloperAPI - def extra_compute_grad_fetches(self): - """Extra values to fetch and return from compute_gradients().""" - return {LEARNER_STATS_KEY: {}} # e.g, stats, td error, etc. 
- - @DeveloperAPI - def optimizer(self): - """TF optimizer to use for policy optimization.""" - return tf.train.AdamOptimizer() - - @DeveloperAPI - def gradients(self, optimizer, loss): - """Override for custom gradient computation.""" - return optimizer.compute_gradients(loss) - - @DeveloperAPI - def build_apply_op(self, optimizer, grads_and_vars): - """Override for custom gradient apply computation.""" - - # specify global_step for TD3 which needs to count the num updates - return optimizer.apply_gradients( - self._grads_and_vars, - global_step=tf.train.get_or_create_global_step()) - - @DeveloperAPI - def _get_is_training_placeholder(self): - """Get the placeholder for _is_training, i.e., for batch norm layers. - - This can be called safely before __init__ has run. - """ - if not hasattr(self, "_is_training"): - self._is_training = tf.placeholder_with_default(False, (), name='is_training') - return self._is_training - - def _extra_input_signature_def(self): - """Extra input signatures to add when exporting tf model. - Inferred from extra_compute_action_feed_dict() - """ - feed_dict = self.extra_compute_action_feed_dict() - return { - k.name: tf.saved_model.utils.build_tensor_info(k) - for k in feed_dict.keys() - } - - def _extra_output_signature_def(self): - """Extra output signatures to add when exporting tf model. - Inferred from extra_compute_action_fetches() - """ - fetches = self.extra_compute_action_fetches() - return { - k: tf.saved_model.utils.build_tensor_info(fetches[k]) - for k in fetches.keys() - } - - def _build_signature_def(self): - """Build signature def map for tensorflow SavedModelBuilder. - """ - # build input signatures - input_signature = self._extra_input_signature_def() - input_signature["observations"] = \ - tf.saved_model.utils.build_tensor_info(self._obs_input) - - if self._seq_lens is not None: - input_signature["seq_lens"] = \ - tf.saved_model.utils.build_tensor_info(self._seq_lens) - if self._prev_action_input is not None: - input_signature["prev_action"] = \ - tf.saved_model.utils.build_tensor_info(self._prev_action_input) - if self._prev_reward_input is not None: - input_signature["prev_reward"] = \ - tf.saved_model.utils.build_tensor_info(self._prev_reward_input) - input_signature["is_training"] = \ - tf.saved_model.utils.build_tensor_info(self._is_training) - - for state_input in self._state_inputs: - input_signature[state_input.name] = \ - tf.saved_model.utils.build_tensor_info(state_input) - - # build output signatures - output_signature = self._extra_output_signature_def() - output_signature["actions"] = \ - tf.saved_model.utils.build_tensor_info(self._sampler) - for state_output in self._state_outputs: - output_signature[state_output.name] = \ - tf.saved_model.utils.build_tensor_info(state_output) - signature_def = ( - tf.saved_model.signature_def_utils.build_signature_def( - input_signature, output_signature, - tf.saved_model.signature_constants.PREDICT_METHOD_NAME)) - signature_def_key = (tf.saved_model.signature_constants. - DEFAULT_SERVING_SIGNATURE_DEF_KEY) - signature_def_map = {signature_def_key: signature_def} - return signature_def_map - - def _build_compute_actions(self, - builder, - obs_batch, - state_batches=None, - prev_action_batch=None, - prev_reward_batch=None, - episodes=None): - state_batches = state_batches or [] - if len(self._state_inputs) != len(state_batches): - raise ValueError( - "Must pass in RNN state batches for placeholders {}, got {}". 
- format(self._state_inputs, state_batches)) - builder.add_feed_dict(self.extra_compute_action_feed_dict()) - builder.add_feed_dict({self._obs_input: obs_batch}) - if state_batches: - builder.add_feed_dict({self._seq_lens: np.ones(len(obs_batch))}) - if self._prev_action_input is not None and prev_action_batch: - builder.add_feed_dict({self._prev_action_input: prev_action_batch}) - if self._prev_reward_input is not None and prev_reward_batch: - builder.add_feed_dict({self._prev_reward_input: prev_reward_batch}) - builder.add_feed_dict({self._is_training: False}) - builder.add_feed_dict(dict(zip(self._state_inputs, state_batches))) - fetches = builder.add_fetches([self._sampler] + self._state_outputs + - [self.extra_compute_action_fetches()]) - return fetches[0], fetches[1:-1], fetches[-1] - - def _build_compute_gradients(self, builder, postprocessed_batch): - builder.add_feed_dict(self.extra_compute_grad_feed_dict()) - builder.add_feed_dict({self._is_training: True}) - builder.add_feed_dict(self._get_loss_inputs_dict(postprocessed_batch)) - fetches = builder.add_fetches( - [self._grads, self._get_grad_and_stats_fetches()]) - return fetches[0], fetches[1] - - def _build_apply_gradients(self, builder, gradients): - if len(gradients) != len(self._grads): - raise ValueError( - "Unexpected number of gradients to apply, got {} for {}". - format(gradients, self._grads)) - builder.add_feed_dict({self._is_training: True}) - builder.add_feed_dict(dict(zip(self._grads, gradients))) - fetches = builder.add_fetches([self._apply_op]) - return fetches[0] - - def _build_learn_on_batch(self, builder, postprocessed_batch): - builder.add_feed_dict(self.extra_compute_grad_feed_dict()) - builder.add_feed_dict(self._get_loss_inputs_dict(postprocessed_batch)) - builder.add_feed_dict({self._is_training: True}) - fetches = builder.add_fetches([ - self._apply_op, - self._get_grad_and_stats_fetches(), - ]) - return fetches[1] - - def _get_grad_and_stats_fetches(self): - fetches = self.extra_compute_grad_fetches() - if LEARNER_STATS_KEY not in fetches: - raise ValueError( - "Grad fetches should contain 'stats': {...} entry") - if self._stats_fetches: - fetches[LEARNER_STATS_KEY] = dict(self._stats_fetches, - **fetches[LEARNER_STATS_KEY]) - return fetches - - def _get_loss_inputs_dict(self, batch): - feed_dict = {} - if self._batch_divisibility_req > 1: - meets_divisibility_reqs = ( - len(batch[SampleBatch.CUR_OBS]) % - self._batch_divisibility_req == 0 - and max(batch[SampleBatch.AGENT_INDEX]) == 0) # not multiagent - else: - meets_divisibility_reqs = True - - # Simple case: not RNN nor do we need to pad - if not self._state_inputs and meets_divisibility_reqs: - for k, ph in self._loss_inputs: - feed_dict[ph] = batch[k] - return feed_dict - - if self._state_inputs: - max_seq_len = self._max_seq_len - dynamic_max = True - else: - max_seq_len = self._batch_divisibility_req - dynamic_max = False - - # RNN or multi-agent case - feature_keys = [k for k, v in self._loss_inputs] - state_keys = [ - "state_in_{}".format(i) for i in range(len(self._state_inputs)) - ] - feature_sequences, initial_states, seq_lens = chop_into_sequences( - batch[SampleBatch.EPS_ID], - batch[SampleBatch.UNROLL_ID], - batch[SampleBatch.AGENT_INDEX], [batch[k] for k in feature_keys], - [batch[k] for k in state_keys], - max_seq_len, - dynamic_max=dynamic_max) - for k, v in zip(feature_keys, feature_sequences): - feed_dict[self._loss_input_dict[k]] = v - for k, v in zip(state_keys, initial_states): - feed_dict[self._loss_input_dict[k]] = v - 
feed_dict[self._seq_lens] = seq_lens - - if log_once("rnn_feed_dict"): - logger.info("Padded input for RNN:\n\n{}\n".format( - summarize({ - "features": feature_sequences, - "initial_states": initial_states, - "seq_lens": seq_lens, - "max_seq_len": max_seq_len, - }))) - return feed_dict - - -@DeveloperAPI -class LearningRateSchedule(object): - """Mixin for TFPolicyGraph that adds a learning rate schedule.""" - - @DeveloperAPI - def __init__(self, lr, lr_schedule): - self.cur_lr = tf.get_variable("lr", initializer=lr) - if lr_schedule is None: - self.lr_schedule = ConstantSchedule(lr) - elif isinstance(lr_schedule, list): - self.lr_schedule = PiecewiseSchedule( - lr_schedule, outside_value=lr_schedule[-1][-1]) - elif isinstance(lr_schedule, dict): - self.lr_schedule = LinearSchedule( - schedule_timesteps=lr_schedule["schedule_timesteps"], - initial_p=lr, - final_p=lr_schedule["final_lr"]) - else: - raise ValueError('lr_schedule must be either list, dict or None') - - @override(PolicyGraph) - def on_global_var_update(self, global_vars): - super(LearningRateSchedule, self).on_global_var_update(global_vars) - self.cur_lr.load( - self.lr_schedule.value(global_vars["timestep"]), - session=self._sess) - - @override(TFPolicyGraph) - def optimizer(self): - return tf.train.AdamOptimizer(self.cur_lr) +TFPolicyGraph = renamed_class(TFPolicy, old_name="TFPolicyGraph") diff --git a/python/ray/rllib/evaluation/tf_policy_template.py b/python/ray/rllib/evaluation/tf_policy_template.py new file mode 100644 index 000000000000..36f482f18bf8 --- /dev/null +++ b/python/ray/rllib/evaluation/tf_policy_template.py @@ -0,0 +1,146 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from ray.rllib.policy.dynamic_tf_policy import DynamicTFPolicy +from ray.rllib.policy.policy import Policy +from ray.rllib.policy.tf_policy import TFPolicy +from ray.rllib.utils.annotations import override, DeveloperAPI + + +@DeveloperAPI +def build_tf_policy(name, + loss_fn, + get_default_config=None, + stats_fn=None, + grad_stats_fn=None, + extra_action_fetches_fn=None, + postprocess_fn=None, + optimizer_fn=None, + gradients_fn=None, + before_init=None, + before_loss_init=None, + after_init=None, + make_action_sampler=None, + mixins=None, + get_batch_divisibility_req=None): + """Helper function for creating a dynamic tf policy at runtime. + + Arguments: + name (str): name of the policy (e.g., "PPOTFPolicy") + loss_fn (func): function that returns a loss tensor the policy, + and dict of experience tensor placeholders + get_default_config (func): optional function that returns the default + config to merge with any overrides + stats_fn (func): optional function that returns a dict of + TF fetches given the policy and batch input tensors + grad_stats_fn (func): optional function that returns a dict of + TF fetches given the policy and loss gradient tensors + extra_action_fetches_fn (func): optional function that returns + a dict of TF fetches given the policy object + postprocess_fn (func): optional experience postprocessing function + that takes the same args as Policy.postprocess_trajectory() + optimizer_fn (func): optional function that returns a tf.Optimizer + given the policy and config + gradients_fn (func): optional function that returns a list of gradients + given a tf optimizer and loss tensor. 
If not specified, this + defaults to optimizer.compute_gradients(loss) + before_init (func): optional function to run at the beginning of + policy init that takes the same arguments as the policy constructor + before_loss_init (func): optional function to run prior to loss + init that takes the same arguments as the policy constructor + after_init (func): optional function to run at the end of policy init + that takes the same arguments as the policy constructor + make_action_sampler (func): optional function that returns a + tuple of action and action prob tensors. The function takes + (policy, input_dict, obs_space, action_space, config) as its + arguments + mixins (list): list of any class mixins for the returned policy class. + These mixins will be applied in order and will have higher + precedence than the DynamicTFPolicy class + get_batch_divisibility_req (func): optional function that returns + the divisibility requirement for sample batches + + Returns: + a DynamicTFPolicy instance that uses the specified args + """ + + if not name.endswith("TFPolicy"): + raise ValueError("Name should match *TFPolicy", name) + + base = DynamicTFPolicy + while mixins: + + class new_base(mixins.pop(), base): + pass + + base = new_base + + class policy_cls(base): + def __init__(self, + obs_space, + action_space, + config, + existing_inputs=None): + if get_default_config: + config = dict(get_default_config(), **config) + + if before_init: + before_init(self, obs_space, action_space, config) + + def before_loss_init_wrapper(policy, obs_space, action_space, + config): + if before_loss_init: + before_loss_init(policy, obs_space, action_space, config) + if extra_action_fetches_fn is None: + self._extra_action_fetches = {} + else: + self._extra_action_fetches = extra_action_fetches_fn(self) + + DynamicTFPolicy.__init__( + self, + obs_space, + action_space, + config, + loss_fn, + stats_fn=stats_fn, + grad_stats_fn=grad_stats_fn, + before_loss_init=before_loss_init_wrapper, + existing_inputs=existing_inputs) + + if after_init: + after_init(self, obs_space, action_space, config) + + @override(Policy) + def postprocess_trajectory(self, + sample_batch, + other_agent_batches=None, + episode=None): + if not postprocess_fn: + return sample_batch + return postprocess_fn(self, sample_batch, other_agent_batches, + episode) + + @override(TFPolicy) + def optimizer(self): + if optimizer_fn: + return optimizer_fn(self, self.config) + else: + return TFPolicy.optimizer(self) + + @override(TFPolicy) + def gradients(self, optimizer, loss): + if gradients_fn: + return gradients_fn(self, optimizer, loss) + else: + return TFPolicy.gradients(self, optimizer, loss) + + @override(TFPolicy) + def extra_compute_action_fetches(self): + return dict( + TFPolicy.extra_compute_action_fetches(self), + **self._extra_action_fetches) + + policy_cls.__name__ = name + policy_cls.__qualname__ = name + return policy_cls diff --git a/python/ray/rllib/evaluation/torch_policy_graph.py b/python/ray/rllib/evaluation/torch_policy_graph.py index 35220dc54570..08cc29fed746 100644 --- a/python/ray/rllib/evaluation/torch_policy_graph.py +++ b/python/ray/rllib/evaluation/torch_policy_graph.py @@ -2,154 +2,7 @@ from __future__ import division from __future__ import print_function -import os +from ray.rllib.policy.torch_policy import TorchPolicy +from ray.rllib.utils import renamed_class -import numpy as np -from threading import Lock - -try: - import torch -except ImportError: - pass # soft dep - -from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY 
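# Illustrative usage sketch, not part of this diff: build_tf_policy() above
# turns a loss function into a DynamicTFPolicy subclass. This assumes loss_fn
# receives (policy, batch_tensors) as described in its docstring and that the
# policy exposes an action_dist attribute, as the built-in dynamic policies
# do; the names "_toy_loss" and "ToyTFPolicy" are made up.
from ray.rllib.evaluation.tf_policy_template import build_tf_policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils import try_import_tf

tf = try_import_tf()


def _toy_loss(policy, batch_tensors):
    # Vanilla policy-gradient style objective: maximize the log-likelihood of
    # the sampled actions, weighted by the raw rewards.
    logp = policy.action_dist.logp(batch_tensors[SampleBatch.ACTIONS])
    return -tf.reduce_mean(logp * batch_tensors[SampleBatch.REWARDS])


# build_tf_policy() requires the name to end in "TFPolicy".
ToyTFPolicy = build_tf_policy(name="ToyTFPolicy", loss_fn=_toy_loss)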
-from ray.rllib.evaluation.policy_graph import PolicyGraph -from ray.rllib.utils.annotations import override - - -class TorchPolicyGraph(PolicyGraph): - """Template for a PyTorch policy and loss to use with RLlib. - - This is similar to TFPolicyGraph, but for PyTorch. - - Attributes: - observation_space (gym.Space): observation space of the policy. - action_space (gym.Space): action space of the policy. - lock (Lock): Lock that must be held around PyTorch ops on this graph. - This is necessary when using the async sampler. - """ - - def __init__(self, observation_space, action_space, model, loss, - loss_inputs, action_distribution_cls): - """Build a policy graph from policy and loss torch modules. - - Note that model will be placed on GPU device if CUDA_VISIBLE_DEVICES - is set. Only single GPU is supported for now. - - Arguments: - observation_space (gym.Space): observation space of the policy. - action_space (gym.Space): action space of the policy. - model (nn.Module): PyTorch policy module. Given observations as - input, this module must return a list of outputs where the - first item is action logits, and the rest can be any value. - loss (nn.Module): Loss defined as a PyTorch module. The inputs for - this module are defined by the `loss_inputs` param. This module - returns a single scalar loss. Note that this module should - internally be using the model module. - loss_inputs (list): List of SampleBatch columns that will be - passed to the loss module's forward() function when computing - the loss. For example, ["obs", "action", "advantages"]. - action_distribution_cls (ActionDistribution): Class for action - distribution. - """ - self.observation_space = observation_space - self.action_space = action_space - self.lock = Lock() - self.device = (torch.device("cuda") - if bool(os.environ.get("CUDA_VISIBLE_DEVICES", None)) - else torch.device("cpu")) - self._model = model.to(self.device) - self._loss = loss - self._loss_inputs = loss_inputs - self._optimizer = self.optimizer() - self._action_dist_cls = action_distribution_cls - - @override(PolicyGraph) - def compute_actions(self, - obs_batch, - state_batches=None, - prev_action_batch=None, - prev_reward_batch=None, - info_batch=None, - episodes=None, - **kwargs): - with self.lock: - with torch.no_grad(): - ob = torch.from_numpy(np.array(obs_batch)) \ - .float().to(self.device) - model_out = self._model({"obs": ob}, state_batches) - logits, _, vf, state = model_out - action_dist = self._action_dist_cls(logits) - actions = action_dist.sample() - return (actions.cpu().numpy(), - [h.cpu().numpy() for h in state], - self.extra_action_out(model_out)) - - @override(PolicyGraph) - def compute_gradients(self, postprocessed_batch): - with self.lock: - loss_in = [] - for key in self._loss_inputs: - loss_in.append( - torch.from_numpy(postprocessed_batch[key]).to(self.device)) - loss_out = self._loss(self._model, *loss_in) - self._optimizer.zero_grad() - loss_out.backward() - - grad_process_info = self.extra_grad_process() - - # Note that return values are just references; - # calling zero_grad will modify the values - grads = [] - for p in self._model.parameters(): - if p.grad is not None: - grads.append(p.grad.data.cpu().numpy()) - else: - grads.append(None) - - grad_info = self.extra_grad_info() - grad_info.update(grad_process_info) - return grads, {LEARNER_STATS_KEY: grad_info} - - @override(PolicyGraph) - def apply_gradients(self, gradients): - with self.lock: - for g, p in zip(gradients, self._model.parameters()): - if g is not None: - p.grad = 
torch.from_numpy(g).to(self.device) - self._optimizer.step() - - @override(PolicyGraph) - def get_weights(self): - with self.lock: - return {k: v.cpu() for k, v in self._model.state_dict().items()} - - @override(PolicyGraph) - def set_weights(self, weights): - with self.lock: - self._model.load_state_dict(weights) - - @override(PolicyGraph) - def get_initial_state(self): - return [s.numpy() for s in self._model.state_init()] - - def extra_grad_process(self): - """Allow subclass to do extra processing on gradients and - return processing info.""" - return {} - - def extra_action_out(self, model_out): - """Returns dict of extra info to include in experience batch. - - Arguments: - model_out (list): Outputs of the policy model module.""" - return {} - - def extra_grad_info(self): - """Return dict of extra grad info.""" - - return {} - - def optimizer(self): - """Custom PyTorch optimizer to use.""" - return torch.optim.Adam(self._model.parameters()) +TorchPolicyGraph = renamed_class(TorchPolicy, old_name="TorchPolicyGraph") diff --git a/python/ray/rllib/examples/batch_norm_model.py b/python/ray/rllib/examples/batch_norm_model.py index 7852a62c2c24..c8a3fc83c0e4 100644 --- a/python/ray/rllib/examples/batch_norm_model.py +++ b/python/ray/rllib/examples/batch_norm_model.py @@ -5,13 +5,13 @@ import argparse -import tensorflow as tf -import tensorflow.contrib.slim as slim - import ray from ray import tune from ray.rllib.models import Model, ModelCatalog from ray.rllib.models.misc import normc_initializer +from ray.rllib.utils import try_import_tf + +tf = try_import_tf() parser = argparse.ArgumentParser() parser.add_argument("--num-iters", type=int, default=200) @@ -24,21 +24,21 @@ def _build_layers_v2(self, input_dict, num_outputs, options): hiddens = [256, 256] for i, size in enumerate(hiddens): label = "fc{}".format(i) - last_layer = slim.fully_connected( + last_layer = tf.layers.dense( last_layer, size, - weights_initializer=normc_initializer(1.0), - activation_fn=tf.nn.tanh, - scope=label) + kernel_initializer=normc_initializer(1.0), + activation=tf.nn.tanh, + name=label) # Add a batch norm layer last_layer = tf.layers.batch_normalization( last_layer, training=input_dict["is_training"]) - output = slim.fully_connected( + output = tf.layers.dense( last_layer, num_outputs, - weights_initializer=normc_initializer(0.01), - activation_fn=None, - scope="fc_out") + kernel_initializer=normc_initializer(0.01), + activation=None, + name="fc_out") return output, last_layer diff --git a/python/ray/rllib/examples/carla/README b/python/ray/rllib/examples/carla/README deleted file mode 100644 index a066b048a2a1..000000000000 --- a/python/ray/rllib/examples/carla/README +++ /dev/null @@ -1,14 +0,0 @@ -(Experimental) OpenAI gym environment for https://github.com/carla-simulator/carla - -To run, first download and unpack the Carla binaries from this URL: https://github.com/carla-simulator/carla/releases/tag/0.7.0 - -Note that currently you also need to clone the Python code from `carla/benchmark_branch` which includes the Carla planner. - -Then, you can try running env.py to drive the car. Run one of the train_* scripts to attempt training. - - $ pkill -9 Carla - $ export CARLA_SERVER=/PATH/TO/CARLA_0.7.0/CarlaUE4.sh - $ export CARLA_PY_PATH=/PATH/TO/CARLA_BENCHMARK_BRANCH_REPO/PythonClient - $ python env.py - -Check out the scenarios.py file for different training and test scenarios that can be used. 
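# Illustrative sketch, not part of this diff: the renamed_class() helper used
# throughout this change (PolicyGraph, TFPolicyGraph, TorchPolicyGraph,
# SampleBatch, MultiAgentBatch) lives in ray.rllib.utils and is not shown
# here. Conceptually it returns a thin subclass of the new class that logs a
# deprecation warning when constructed under the old name; a minimal version
# might look like the following, with the wrapper name and log wording being
# assumptions:
import logging

logger = logging.getLogger(__name__)


def renamed_class(cls, old_name):
    """Return a drop-in alias for `cls` that warns when instantiated."""

    class DeprecationWrapper(cls):
        def __init__(self, *args, **kwargs):
            new_name = cls.__module__ + "." + cls.__name__
            logger.warning("DeprecationWarning: %s has been renamed to %s.",
                           old_name, new_name)
            cls.__init__(self, *args, **kwargs)

    DeprecationWrapper.__name__ = old_name
    return DeprecationWrapper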
diff --git a/python/ray/rllib/examples/carla/env.py b/python/ray/rllib/examples/carla/env.py deleted file mode 100644 index af5b619afcdb..000000000000 --- a/python/ray/rllib/examples/carla/env.py +++ /dev/null @@ -1,684 +0,0 @@ -"""OpenAI gym environment for Carla. Run this file for a demo.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from datetime import datetime -import atexit -import cv2 -import os -import json -import random -import signal -import subprocess -import sys -import time -import traceback - -import numpy as np -try: - import scipy.misc -except Exception: - pass - -import gym -from gym.spaces import Box, Discrete, Tuple - -from scenarios import DEFAULT_SCENARIO - -# Set this where you want to save image outputs (or empty string to disable) -CARLA_OUT_PATH = os.environ.get("CARLA_OUT", os.path.expanduser("~/carla_out")) -if CARLA_OUT_PATH and not os.path.exists(CARLA_OUT_PATH): - os.makedirs(CARLA_OUT_PATH) - -# Set this to the path of your Carla binary -SERVER_BINARY = os.environ.get("CARLA_SERVER", - os.path.expanduser("~/CARLA_0.7.0/CarlaUE4.sh")) - -assert os.path.exists(SERVER_BINARY) -if "CARLA_PY_PATH" in os.environ: - sys.path.append(os.path.expanduser(os.environ["CARLA_PY_PATH"])) -else: - # TODO(ekl) switch this to the binary path once the planner is in master - sys.path.append(os.path.expanduser("~/carla/PythonClient/")) - -try: - from carla.client import CarlaClient - from carla.sensor import Camera - from carla.settings import CarlaSettings - from carla.planner.planner import Planner, REACH_GOAL, GO_STRAIGHT, \ - TURN_RIGHT, TURN_LEFT, LANE_FOLLOW -except Exception as e: - print("Failed to import Carla python libs, try setting $CARLA_PY_PATH") - raise e - -# Carla planner commands -COMMANDS_ENUM = { - REACH_GOAL: "REACH_GOAL", - GO_STRAIGHT: "GO_STRAIGHT", - TURN_RIGHT: "TURN_RIGHT", - TURN_LEFT: "TURN_LEFT", - LANE_FOLLOW: "LANE_FOLLOW", -} - -# Mapping from string repr to one-hot encoding index to feed to the model -COMMAND_ORDINAL = { - "REACH_GOAL": 0, - "GO_STRAIGHT": 1, - "TURN_RIGHT": 2, - "TURN_LEFT": 3, - "LANE_FOLLOW": 4, -} - -# Number of retries if the server doesn't respond -RETRIES_ON_ERROR = 5 - -# Dummy Z coordinate to use when we only care about (x, y) -GROUND_Z = 22 - -# Default environment configuration -ENV_CONFIG = { - "log_images": True, - "enable_planner": True, - "framestack": 2, # note: only [1, 2] currently supported - "convert_images_to_video": True, - "early_terminate_on_collision": True, - "verbose": True, - "reward_function": "custom", - "render_x_res": 800, - "render_y_res": 600, - "x_res": 80, - "y_res": 80, - "server_map": "/Game/Maps/Town02", - "scenarios": [DEFAULT_SCENARIO], - "use_depth_camera": False, - "discrete_actions": True, - "squash_action_logits": False, -} - -DISCRETE_ACTIONS = { - # coast - 0: [0.0, 0.0], - # turn left - 1: [0.0, -0.5], - # turn right - 2: [0.0, 0.5], - # forward - 3: [1.0, 0.0], - # brake - 4: [-0.5, 0.0], - # forward left - 5: [1.0, -0.5], - # forward right - 6: [1.0, 0.5], - # brake left - 7: [-0.5, -0.5], - # brake right - 8: [-0.5, 0.5], -} - -live_carla_processes = set() - - -def cleanup(): - print("Killing live carla processes", live_carla_processes) - for pgid in live_carla_processes: - os.killpg(pgid, signal.SIGKILL) - - -atexit.register(cleanup) - - -class CarlaEnv(gym.Env): - def __init__(self, config=ENV_CONFIG): - self.config = config - self.city = self.config["server_map"].split("/")[-1] - if 
self.config["enable_planner"]: - self.planner = Planner(self.city) - - if config["discrete_actions"]: - self.action_space = Discrete(len(DISCRETE_ACTIONS)) - else: - self.action_space = Box(-1.0, 1.0, shape=(2, ), dtype=np.float32) - if config["use_depth_camera"]: - image_space = Box( - -1.0, - 1.0, - shape=(config["y_res"], config["x_res"], - 1 * config["framestack"]), - dtype=np.float32) - else: - image_space = Box( - 0, - 255, - shape=(config["y_res"], config["x_res"], - 3 * config["framestack"]), - dtype=np.uint8) - self.observation_space = Tuple( # forward_speed, dist to goal - [ - image_space, - Discrete(len(COMMANDS_ENUM)), # next_command - Box(-128.0, 128.0, shape=(2, ), dtype=np.float32) - ]) - - # TODO(ekl) this isn't really a proper gym spec - self._spec = lambda: None - self._spec.id = "Carla-v0" - - self.server_port = None - self.server_process = None - self.client = None - self.num_steps = 0 - self.total_reward = 0 - self.prev_measurement = None - self.prev_image = None - self.episode_id = None - self.measurements_file = None - self.weather = None - self.scenario = None - self.start_pos = None - self.end_pos = None - self.start_coord = None - self.end_coord = None - self.last_obs = None - - def init_server(self): - print("Initializing new Carla server...") - # Create a new server process and start the client. - self.server_port = random.randint(10000, 60000) - self.server_process = subprocess.Popen( - [ - SERVER_BINARY, self.config["server_map"], "-windowed", - "-ResX=400", "-ResY=300", "-carla-server", - "-carla-world-port={}".format(self.server_port) - ], - preexec_fn=os.setsid, - stdout=open(os.devnull, "w")) - live_carla_processes.add(os.getpgid(self.server_process.pid)) - - for i in range(RETRIES_ON_ERROR): - try: - self.client = CarlaClient("localhost", self.server_port) - return self.client.connect() - except Exception as e: - print("Error connecting: {}, attempt {}".format(e, i)) - time.sleep(2) - - def clear_server_state(self): - print("Clearing Carla server state") - try: - if self.client: - self.client.disconnect() - self.client = None - except Exception as e: - print("Error disconnecting client: {}".format(e)) - pass - if self.server_process: - pgid = os.getpgid(self.server_process.pid) - os.killpg(pgid, signal.SIGKILL) - live_carla_processes.remove(pgid) - self.server_port = None - self.server_process = None - - def __del__(self): - self.clear_server_state() - - def reset(self): - error = None - for _ in range(RETRIES_ON_ERROR): - try: - if not self.server_process: - self.init_server() - return self._reset() - except Exception as e: - print("Error during reset: {}".format(traceback.format_exc())) - self.clear_server_state() - error = e - raise error - - def _reset(self): - self.num_steps = 0 - self.total_reward = 0 - self.prev_measurement = None - self.prev_image = None - self.episode_id = datetime.today().strftime("%Y-%m-%d_%H-%M-%S_%f") - self.measurements_file = None - - # Create a CarlaSettings object. This object is a wrapper around - # the CarlaSettings.ini file. Here we set the configuration we - # want for the new episode. 
- settings = CarlaSettings() - self.scenario = random.choice(self.config["scenarios"]) - assert self.scenario["city"] == self.city, (self.scenario, self.city) - self.weather = random.choice(self.scenario["weather_distribution"]) - settings.set( - SynchronousMode=True, - SendNonPlayerAgentsInfo=True, - NumberOfVehicles=self.scenario["num_vehicles"], - NumberOfPedestrians=self.scenario["num_pedestrians"], - WeatherId=self.weather) - settings.randomize_seeds() - - if self.config["use_depth_camera"]: - camera1 = Camera("CameraDepth", PostProcessing="Depth") - camera1.set_image_size(self.config["render_x_res"], - self.config["render_y_res"]) - camera1.set_position(30, 0, 130) - settings.add_sensor(camera1) - - camera2 = Camera("CameraRGB") - camera2.set_image_size(self.config["render_x_res"], - self.config["render_y_res"]) - camera2.set_position(30, 0, 130) - settings.add_sensor(camera2) - - # Setup start and end positions - scene = self.client.load_settings(settings) - positions = scene.player_start_spots - self.start_pos = positions[self.scenario["start_pos_id"]] - self.end_pos = positions[self.scenario["end_pos_id"]] - self.start_coord = [ - self.start_pos.location.x // 100, self.start_pos.location.y // 100 - ] - self.end_coord = [ - self.end_pos.location.x // 100, self.end_pos.location.y // 100 - ] - print("Start pos {} ({}), end {} ({})".format( - self.scenario["start_pos_id"], self.start_coord, - self.scenario["end_pos_id"], self.end_coord)) - - # Notify the server that we want to start the episode at the - # player_start index. This function blocks until the server is ready - # to start the episode. - print("Starting new episode...") - self.client.start_episode(self.scenario["start_pos_id"]) - - image, py_measurements = self._read_observation() - self.prev_measurement = py_measurements - return self.encode_obs(self.preprocess_image(image), py_measurements) - - def encode_obs(self, image, py_measurements): - assert self.config["framestack"] in [1, 2] - prev_image = self.prev_image - self.prev_image = image - if prev_image is None: - prev_image = image - if self.config["framestack"] == 2: - image = np.concatenate([prev_image, image], axis=2) - obs = (image, COMMAND_ORDINAL[py_measurements["next_command"]], [ - py_measurements["forward_speed"], - py_measurements["distance_to_goal"] - ]) - self.last_obs = obs - return obs - - def step(self, action): - try: - obs = self._step(action) - return obs - except Exception: - print("Error during step, terminating episode early", - traceback.format_exc()) - self.clear_server_state() - return (self.last_obs, 0.0, True, {}) - - def _step(self, action): - if self.config["discrete_actions"]: - action = DISCRETE_ACTIONS[int(action)] - assert len(action) == 2, "Invalid action {}".format(action) - if self.config["squash_action_logits"]: - forward = 2 * float(sigmoid(action[0]) - 0.5) - throttle = float(np.clip(forward, 0, 1)) - brake = float(np.abs(np.clip(forward, -1, 0))) - steer = 2 * float(sigmoid(action[1]) - 0.5) - else: - throttle = float(np.clip(action[0], 0, 1)) - brake = float(np.abs(np.clip(action[0], -1, 0))) - steer = float(np.clip(action[1], -1, 1)) - reverse = False - hand_brake = False - - if self.config["verbose"]: - print("steer", steer, "throttle", throttle, "brake", brake, - "reverse", reverse) - - self.client.send_control( - steer=steer, - throttle=throttle, - brake=brake, - hand_brake=hand_brake, - reverse=reverse) - - # Process observations - image, py_measurements = self._read_observation() - if self.config["verbose"]: - print("Next 
command", py_measurements["next_command"]) - if type(action) is np.ndarray: - py_measurements["action"] = [float(a) for a in action] - else: - py_measurements["action"] = action - py_measurements["control"] = { - "steer": steer, - "throttle": throttle, - "brake": brake, - "reverse": reverse, - "hand_brake": hand_brake, - } - reward = compute_reward(self, self.prev_measurement, py_measurements) - self.total_reward += reward - py_measurements["reward"] = reward - py_measurements["total_reward"] = self.total_reward - done = (self.num_steps > self.scenario["max_steps"] - or py_measurements["next_command"] == "REACH_GOAL" - or (self.config["early_terminate_on_collision"] - and collided_done(py_measurements))) - py_measurements["done"] = done - self.prev_measurement = py_measurements - - # Write out measurements to file - if CARLA_OUT_PATH: - if not self.measurements_file: - self.measurements_file = open( - os.path.join( - CARLA_OUT_PATH, - "measurements_{}.json".format(self.episode_id)), "w") - self.measurements_file.write(json.dumps(py_measurements)) - self.measurements_file.write("\n") - if done: - self.measurements_file.close() - self.measurements_file = None - if self.config["convert_images_to_video"]: - self.images_to_video() - - self.num_steps += 1 - image = self.preprocess_image(image) - return (self.encode_obs(image, py_measurements), reward, done, - py_measurements) - - def images_to_video(self): - videos_dir = os.path.join(CARLA_OUT_PATH, "Videos") - if not os.path.exists(videos_dir): - os.makedirs(videos_dir) - ffmpeg_cmd = ( - "ffmpeg -loglevel -8 -r 60 -f image2 -s {x_res}x{y_res} " - "-start_number 0 -i " - "{img}_%04d.jpg -vcodec libx264 {vid}.mp4 && rm -f {img}_*.jpg " - ).format( - x_res=self.config["render_x_res"], - y_res=self.config["render_y_res"], - vid=os.path.join(videos_dir, self.episode_id), - img=os.path.join(CARLA_OUT_PATH, "CameraRGB", self.episode_id)) - print("Executing ffmpeg command", ffmpeg_cmd) - subprocess.call(ffmpeg_cmd, shell=True) - - def preprocess_image(self, image): - if self.config["use_depth_camera"]: - assert self.config["use_depth_camera"] - data = (image.data - 0.5) * 2 - data = data.reshape(self.config["render_y_res"], - self.config["render_x_res"], 1) - data = cv2.resize( - data, (self.config["x_res"], self.config["y_res"]), - interpolation=cv2.INTER_AREA) - data = np.expand_dims(data, 2) - else: - data = image.data.reshape(self.config["render_y_res"], - self.config["render_x_res"], 3) - data = cv2.resize( - data, (self.config["x_res"], self.config["y_res"]), - interpolation=cv2.INTER_AREA) - data = (data.astype(np.float32) - 128) / 128 - return data - - def _read_observation(self): - # Read the data produced by the server this frame. - measurements, sensor_data = self.client.read_data() - - # Print some of the measurements. 
- if self.config["verbose"]: - print_measurements(measurements) - - observation = None - if self.config["use_depth_camera"]: - camera_name = "CameraDepth" - else: - camera_name = "CameraRGB" - for name, image in sensor_data.items(): - if name == camera_name: - observation = image - - cur = measurements.player_measurements - - if self.config["enable_planner"]: - next_command = COMMANDS_ENUM[self.planner.get_next_command( - [cur.transform.location.x, cur.transform.location.y, GROUND_Z], - [ - cur.transform.orientation.x, cur.transform.orientation.y, - GROUND_Z - ], - [self.end_pos.location.x, self.end_pos.location.y, GROUND_Z], [ - self.end_pos.orientation.x, self.end_pos.orientation.y, - GROUND_Z - ])] - else: - next_command = "LANE_FOLLOW" - - if next_command == "REACH_GOAL": - distance_to_goal = 0.0 # avoids crash in planner - elif self.config["enable_planner"]: - distance_to_goal = self.planner.get_shortest_path_distance([ - cur.transform.location.x, cur.transform.location.y, GROUND_Z - ], [ - cur.transform.orientation.x, cur.transform.orientation.y, - GROUND_Z - ], [self.end_pos.location.x, self.end_pos.location.y, GROUND_Z], [ - self.end_pos.orientation.x, self.end_pos.orientation.y, - GROUND_Z - ]) / 100 - else: - distance_to_goal = -1 - - distance_to_goal_euclidean = float( - np.linalg.norm([ - cur.transform.location.x - self.end_pos.location.x, - cur.transform.location.y - self.end_pos.location.y - ]) / 100) - - py_measurements = { - "episode_id": self.episode_id, - "step": self.num_steps, - "x": cur.transform.location.x, - "y": cur.transform.location.y, - "x_orient": cur.transform.orientation.x, - "y_orient": cur.transform.orientation.y, - "forward_speed": cur.forward_speed, - "distance_to_goal": distance_to_goal, - "distance_to_goal_euclidean": distance_to_goal_euclidean, - "collision_vehicles": cur.collision_vehicles, - "collision_pedestrians": cur.collision_pedestrians, - "collision_other": cur.collision_other, - "intersection_offroad": cur.intersection_offroad, - "intersection_otherlane": cur.intersection_otherlane, - "weather": self.weather, - "map": self.config["server_map"], - "start_coord": self.start_coord, - "end_coord": self.end_coord, - "current_scenario": self.scenario, - "x_res": self.config["x_res"], - "y_res": self.config["y_res"], - "num_vehicles": self.scenario["num_vehicles"], - "num_pedestrians": self.scenario["num_pedestrians"], - "max_steps": self.scenario["max_steps"], - "next_command": next_command, - } - - if CARLA_OUT_PATH and self.config["log_images"]: - for name, image in sensor_data.items(): - out_dir = os.path.join(CARLA_OUT_PATH, name) - if not os.path.exists(out_dir): - os.makedirs(out_dir) - out_file = os.path.join( - out_dir, "{}_{:>04}.jpg".format(self.episode_id, - self.num_steps)) - scipy.misc.imsave(out_file, image.data) - - assert observation is not None, sensor_data - return observation, py_measurements - - -def compute_reward_corl2017(env, prev, current): - reward = 0.0 - - cur_dist = current["distance_to_goal"] - - prev_dist = prev["distance_to_goal"] - - if env.config["verbose"]: - print("Cur dist {}, prev dist {}".format(cur_dist, prev_dist)) - - # Distance travelled toward the goal in m - reward += np.clip(prev_dist - cur_dist, -10.0, 10.0) - - # Change in speed (km/h) - reward += 0.05 * (current["forward_speed"] - prev["forward_speed"]) - - # New collision damage - reward -= .00002 * ( - current["collision_vehicles"] + current["collision_pedestrians"] + - current["collision_other"] - prev["collision_vehicles"] - - 
prev["collision_pedestrians"] - prev["collision_other"]) - - # New sidewalk intersection - reward -= 2 * ( - current["intersection_offroad"] - prev["intersection_offroad"]) - - # New opposite lane intersection - reward -= 2 * ( - current["intersection_otherlane"] - prev["intersection_otherlane"]) - - return reward - - -def compute_reward_custom(env, prev, current): - reward = 0.0 - - cur_dist = current["distance_to_goal"] - prev_dist = prev["distance_to_goal"] - - if env.config["verbose"]: - print("Cur dist {}, prev dist {}".format(cur_dist, prev_dist)) - - # Distance travelled toward the goal in m - reward += np.clip(prev_dist - cur_dist, -10.0, 10.0) - - # Speed reward, up 30.0 (km/h) - reward += np.clip(current["forward_speed"], 0.0, 30.0) / 10 - - # New collision damage - new_damage = ( - current["collision_vehicles"] + current["collision_pedestrians"] + - current["collision_other"] - prev["collision_vehicles"] - - prev["collision_pedestrians"] - prev["collision_other"]) - if new_damage: - reward -= 100.0 - - # Sidewalk intersection - reward -= current["intersection_offroad"] - - # Opposite lane intersection - reward -= current["intersection_otherlane"] - - # Reached goal - if current["next_command"] == "REACH_GOAL": - reward += 100.0 - - return reward - - -def compute_reward_lane_keep(env, prev, current): - reward = 0.0 - - # Speed reward, up 30.0 (km/h) - reward += np.clip(current["forward_speed"], 0.0, 30.0) / 10 - - # New collision damage - new_damage = ( - current["collision_vehicles"] + current["collision_pedestrians"] + - current["collision_other"] - prev["collision_vehicles"] - - prev["collision_pedestrians"] - prev["collision_other"]) - if new_damage: - reward -= 100.0 - - # Sidewalk intersection - reward -= current["intersection_offroad"] - - # Opposite lane intersection - reward -= current["intersection_otherlane"] - - return reward - - -REWARD_FUNCTIONS = { - "corl2017": compute_reward_corl2017, - "custom": compute_reward_custom, - "lane_keep": compute_reward_lane_keep, -} - - -def compute_reward(env, prev, current): - return REWARD_FUNCTIONS[env.config["reward_function"]](env, prev, current) - - -def print_measurements(measurements): - number_of_agents = len(measurements.non_player_agents) - player_measurements = measurements.player_measurements - message = "Vehicle at ({pos_x:.1f}, {pos_y:.1f}), " - message += "{speed:.2f} km/h, " - message += "Collision: {{vehicles={col_cars:.0f}, " - message += "pedestrians={col_ped:.0f}, other={col_other:.0f}}}, " - message += "{other_lane:.0f}% other lane, {offroad:.0f}% off-road, " - message += "({agents_num:d} non-player agents in the scene)" - message = message.format( - pos_x=player_measurements.transform.location.x / 100, # cm -> m - pos_y=player_measurements.transform.location.y / 100, - speed=player_measurements.forward_speed, - col_cars=player_measurements.collision_vehicles, - col_ped=player_measurements.collision_pedestrians, - col_other=player_measurements.collision_other, - other_lane=100 * player_measurements.intersection_otherlane, - offroad=100 * player_measurements.intersection_offroad, - agents_num=number_of_agents) - print(message) - - -def sigmoid(x): - x = float(x) - return np.exp(x) / (1 + np.exp(x)) - - -def collided_done(py_measurements): - m = py_measurements - collided = (m["collision_vehicles"] > 0 or m["collision_pedestrians"] > 0 - or m["collision_other"] > 0) - return bool(collided or m["total_reward"] < -100) - - -if __name__ == "__main__": - for _ in range(2): - env = CarlaEnv() - obs = env.reset() - 
print("reset", obs) - start = time.time() - done = False - i = 0 - total_reward = 0.0 - while not done: - i += 1 - if ENV_CONFIG["discrete_actions"]: - obs, reward, done, info = env.step(1) - else: - obs, reward, done, info = env.step([0, 1, 0]) - total_reward += reward - print(i, "rew", reward, "total", total_reward, "done", done) - print("{} fps".format(100 / (time.time() - start))) diff --git a/python/ray/rllib/examples/carla/models.py b/python/ray/rllib/examples/carla/models.py deleted file mode 100644 index 3f8cc0c5ba47..000000000000 --- a/python/ray/rllib/examples/carla/models.py +++ /dev/null @@ -1,108 +0,0 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np -import tensorflow as tf -import tensorflow.contrib.slim as slim -from tensorflow.contrib.layers import xavier_initializer - -from ray.rllib.models.catalog import ModelCatalog -from ray.rllib.models.misc import normc_initializer -from ray.rllib.models.model import Model - - -class CarlaModel(Model): - """Carla model that can process the observation tuple. - - The architecture processes the image using convolutional layers, the - metrics using fully connected layers, and then combines them with - further fully connected layers. - """ - - # TODO(ekl): use build_layers_v2 for native dict space support - def _build_layers(self, inputs, num_outputs, options): - # Parse options - image_shape = options["custom_options"]["image_shape"] - convs = options.get("conv_filters", [ - [16, [8, 8], 4], - [32, [5, 5], 3], - [32, [5, 5], 2], - [512, [10, 10], 1], - ]) - hiddens = options.get("fcnet_hiddens", [64]) - fcnet_activation = options.get("fcnet_activation", "tanh") - if fcnet_activation == "tanh": - activation = tf.nn.tanh - elif fcnet_activation == "relu": - activation = tf.nn.relu - - # Sanity checks - image_size = np.product(image_shape) - expected_shape = [image_size + 5 + 2] - assert inputs.shape.as_list()[1:] == expected_shape, \ - (inputs.shape.as_list()[1:], expected_shape) - - # Reshape the input vector back into its components - vision_in = tf.reshape(inputs[:, :image_size], - [tf.shape(inputs)[0]] + image_shape) - metrics_in = inputs[:, image_size:] - print("Vision in shape", vision_in) - print("Metrics in shape", metrics_in) - - # Setup vision layers - with tf.name_scope("carla_vision"): - for i, (out_size, kernel, stride) in enumerate(convs[:-1], 1): - vision_in = slim.conv2d( - vision_in, - out_size, - kernel, - stride, - scope="conv{}".format(i)) - out_size, kernel, stride = convs[-1] - vision_in = slim.conv2d( - vision_in, - out_size, - kernel, - stride, - padding="VALID", - scope="conv_out") - vision_in = tf.squeeze(vision_in, [1, 2]) - - # Setup metrics layer - with tf.name_scope("carla_metrics"): - metrics_in = slim.fully_connected( - metrics_in, - 64, - weights_initializer=xavier_initializer(), - activation_fn=activation, - scope="metrics_out") - - print("Shape of vision out is", vision_in.shape) - print("Shape of metric out is", metrics_in.shape) - - # Combine the metrics and vision inputs - with tf.name_scope("carla_out"): - i = 1 - last_layer = tf.concat([vision_in, metrics_in], axis=1) - print("Shape of concatenated out is", last_layer.shape) - for size in hiddens: - last_layer = slim.fully_connected( - last_layer, - size, - weights_initializer=xavier_initializer(), - activation_fn=activation, - scope="fc{}".format(i)) - i += 1 - output = slim.fully_connected( - last_layer, - num_outputs, - 
weights_initializer=normc_initializer(0.01), - activation_fn=None, - scope="fc_out") - - return output, last_layer - - -def register_carla_model(): - ModelCatalog.register_custom_model("carla", CarlaModel) diff --git a/python/ray/rllib/examples/carla/scenarios.py b/python/ray/rllib/examples/carla/scenarios.py deleted file mode 100644 index beedd2989d5c..000000000000 --- a/python/ray/rllib/examples/carla/scenarios.py +++ /dev/null @@ -1,131 +0,0 @@ -"""Collection of Carla scenarios, including those from the CoRL 2017 paper.""" - -TEST_WEATHERS = [0, 2, 5, 7, 9, 10, 11, 12, 13] -TRAIN_WEATHERS = [1, 3, 4, 6, 8, 14] - - -def build_scenario(city, start, end, vehicles, pedestrians, max_steps, - weathers): - return { - "city": city, - "num_vehicles": vehicles, - "num_pedestrians": pedestrians, - "weather_distribution": weathers, - "start_pos_id": start, - "end_pos_id": end, - "max_steps": max_steps, - } - - -# Simple scenario for Town02 that involves driving down a road -DEFAULT_SCENARIO = build_scenario( - city="Town02", - start=36, - end=40, - vehicles=20, - pedestrians=40, - max_steps=200, - weathers=[0]) - -# Simple scenario for Town02 that involves driving down a road -LANE_KEEP = build_scenario( - city="Town02", - start=36, - end=40, - vehicles=0, - pedestrians=0, - max_steps=2000, - weathers=[0]) - -# Scenarios from the CoRL2017 paper -POSES_TOWN1_STRAIGHT = [[36, 40], [39, 35], [110, 114], [7, 3], [0, 4], [ - 68, 50 -], [61, 59], [47, 64], [147, 90], [33, 87], [26, 19], [80, 76], [45, 49], [ - 55, 44 -], [29, 107], [95, 104], [84, 34], [53, 67], [22, 17], [91, 148], [20, 107], - [78, 70], [95, 102], [68, 44], [45, 69]] - -POSES_TOWN1_ONE_CURVE = [[138, 17], [47, 16], [26, 9], [42, 49], [140, 124], [ - 85, 98 -], [65, 133], [137, 51], [76, 66], [46, 39], [40, 60], [0, 29], [4, 129], [ - 121, 140 -], [2, 129], [78, 44], [68, 85], [41, 102], [95, 70], [68, 129], [84, 69], - [47, 79], [110, 15], [130, 17], [0, 17]] - -POSES_TOWN1_NAV = [[105, 29], [27, 130], [102, 87], [132, 27], [24, 44], [ - 96, 26 -], [34, 67], [28, 1], [140, 134], [105, 9], [148, 129], [65, 18], [21, 16], [ - 147, 97 -], [42, 51], [30, 41], [18, 107], [69, 45], [102, 95], [18, 145], [111, 64], - [79, 45], [84, 69], [73, 31], [37, 81]] - -POSES_TOWN2_STRAIGHT = [[38, 34], [4, 2], [12, 10], [62, 55], [43, 47], [ - 64, 66 -], [78, 76], [59, 57], [61, 18], [35, 39], [12, 8], [0, 18], [75, 68], [ - 54, 60 -], [45, 49], [46, 42], [53, 46], [80, 29], [65, 63], [0, 81], [54, 63], - [51, 42], [16, 19], [17, 26], [77, 68]] - -POSES_TOWN2_ONE_CURVE = [[37, 76], [8, 24], [60, 69], [38, 10], [21, 1], [ - 58, 71 -], [74, 32], [44, 0], [71, 16], [14, 24], [34, 11], [43, 14], [75, 16], [ - 80, 21 -], [3, 23], [75, 59], [50, 47], [11, 19], [77, 34], [79, 25], [40, 63], - [58, 76], [79, 55], [16, 61], [27, 11]] - -POSES_TOWN2_NAV = [[19, 66], [79, 14], [19, 57], [23, 1], [53, 76], [42, 13], [ - 31, 71 -], [33, 5], [54, 30], [10, 61], [66, 3], [27, 12], [79, 19], [2, 29], [16, 14], - [5, 57], [70, 73], [46, 67], [57, 50], [61, 49], [21, 12], - [51, 81], [77, 68], [56, 65], [43, 54]] - -TOWN1_STRAIGHT = [ - build_scenario("Town01", start, end, 0, 0, 300, TEST_WEATHERS) - for (start, end) in POSES_TOWN1_STRAIGHT -] - -TOWN1_ONE_CURVE = [ - build_scenario("Town01", start, end, 0, 0, 600, TEST_WEATHERS) - for (start, end) in POSES_TOWN1_ONE_CURVE -] - -TOWN1_NAVIGATION = [ - build_scenario("Town01", start, end, 0, 0, 900, TEST_WEATHERS) - for (start, end) in POSES_TOWN1_NAV -] - -TOWN1_NAVIGATION_DYNAMIC = [ - build_scenario("Town01", start, 
end, 20, 50, 900, TEST_WEATHERS) - for (start, end) in POSES_TOWN1_NAV -] - -TOWN2_STRAIGHT = [ - build_scenario("Town02", start, end, 0, 0, 300, TRAIN_WEATHERS) - for (start, end) in POSES_TOWN2_STRAIGHT -] - -TOWN2_STRAIGHT_DYNAMIC = [ - build_scenario("Town02", start, end, 20, 50, 300, TRAIN_WEATHERS) - for (start, end) in POSES_TOWN2_STRAIGHT -] - -TOWN2_ONE_CURVE = [ - build_scenario("Town02", start, end, 0, 0, 600, TRAIN_WEATHERS) - for (start, end) in POSES_TOWN2_ONE_CURVE -] - -TOWN2_NAVIGATION = [ - build_scenario("Town02", start, end, 0, 0, 900, TRAIN_WEATHERS) - for (start, end) in POSES_TOWN2_NAV -] - -TOWN2_NAVIGATION_DYNAMIC = [ - build_scenario("Town02", start, end, 20, 50, 900, TRAIN_WEATHERS) - for (start, end) in POSES_TOWN2_NAV -] - -TOWN1_ALL = (TOWN1_STRAIGHT + TOWN1_ONE_CURVE + TOWN1_NAVIGATION + - TOWN1_NAVIGATION_DYNAMIC) - -TOWN2_ALL = (TOWN2_STRAIGHT + TOWN2_ONE_CURVE + TOWN2_NAVIGATION + - TOWN2_NAVIGATION_DYNAMIC) diff --git a/python/ray/rllib/examples/carla/train_a3c.py b/python/ray/rllib/examples/carla/train_a3c.py deleted file mode 100644 index 8fbcfbc576d1..000000000000 --- a/python/ray/rllib/examples/carla/train_a3c.py +++ /dev/null @@ -1,51 +0,0 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import ray -from ray.tune import grid_search, run_experiments - -from env import CarlaEnv, ENV_CONFIG -from models import register_carla_model -from scenarios import TOWN2_STRAIGHT - -env_config = ENV_CONFIG.copy() -env_config.update({ - "verbose": False, - "x_res": 80, - "y_res": 80, - "squash_action_logits": grid_search([False, True]), - "use_depth_camera": False, - "discrete_actions": False, - "server_map": "/Game/Maps/Town02", - "reward_function": grid_search(["custom", "corl2017"]), - "scenarios": TOWN2_STRAIGHT, -}) - -register_carla_model() -redis_address = ray.services.get_node_ip_address() + ":6379" - -ray.init(redis_address=redis_address) -run_experiments({ - "carla-a3c": { - "run": "A3C", - "env": CarlaEnv, - "config": { - "env_config": env_config, - "use_gpu_for_workers": True, - "model": { - "custom_model": "carla", - "custom_options": { - "image_shape": [80, 80, 6], - }, - "conv_filters": [ - [16, [8, 8], 4], - [32, [4, 4], 2], - [512, [10, 10], 1], - ], - }, - "gamma": 0.95, - "num_workers": 2, - }, - }, -}) diff --git a/python/ray/rllib/examples/carla/train_dqn.py b/python/ray/rllib/examples/carla/train_dqn.py deleted file mode 100644 index 27aa65444d38..000000000000 --- a/python/ray/rllib/examples/carla/train_dqn.py +++ /dev/null @@ -1,65 +0,0 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import ray -from ray.tune import run_experiments - -from env import CarlaEnv, ENV_CONFIG -from models import register_carla_model -from scenarios import TOWN2_ONE_CURVE - -env_config = ENV_CONFIG.copy() -env_config.update({ - "verbose": False, - "x_res": 80, - "y_res": 80, - "discrete_actions": True, - "server_map": "/Game/Maps/Town02", - "reward_function": "custom", - "scenarios": TOWN2_ONE_CURVE, -}) - -register_carla_model() - -ray.init() - - -def shape_out(spec): - return (spec.config.env_config.framestack * - (spec.config.env_config.use_depth_camera and 1 or 3)) - - -run_experiments({ - "carla-dqn": { - "run": "DQN", - "env": CarlaEnv, - "config": { - "env_config": env_config, - "model": { - "custom_model": "carla", - "custom_options": { - "image_shape": [ - 80, - 80, - shape_out, - ], - }, - "conv_filters": [ - [16, [8, 
8], 4], - [32, [4, 4], 2], - [512, [10, 10], 1], - ], - }, - "timesteps_per_iteration": 100, - "learning_starts": 1000, - "schedule_max_timesteps": 100000, - "gamma": 0.8, - "tf_session_args": { - "gpu_options": { - "allow_growth": True - }, - }, - }, - }, -}) diff --git a/python/ray/rllib/examples/carla/train_ppo.py b/python/ray/rllib/examples/carla/train_ppo.py deleted file mode 100644 index 130acf3a5849..000000000000 --- a/python/ray/rllib/examples/carla/train_ppo.py +++ /dev/null @@ -1,55 +0,0 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import ray -from ray.tune import run_experiments - -from env import CarlaEnv, ENV_CONFIG -from models import register_carla_model -from scenarios import TOWN2_STRAIGHT - -env_config = ENV_CONFIG.copy() -env_config.update({ - "verbose": False, - "x_res": 80, - "y_res": 80, - "use_depth_camera": False, - "discrete_actions": False, - "server_map": "/Game/Maps/Town02", - "scenarios": TOWN2_STRAIGHT, -}) -register_carla_model() - -ray.init() -run_experiments({ - "carla": { - "run": "PPO", - "env": CarlaEnv, - "config": { - "env_config": env_config, - "model": { - "custom_model": "carla", - "custom_options": { - "image_shape": [ - env_config["x_res"], env_config["y_res"], 6 - ], - }, - "conv_filters": [ - [16, [8, 8], 4], - [32, [4, 4], 2], - [512, [10, 10], 1], - ], - }, - "num_workers": 1, - "train_batch_size": 2000, - "sample_batch_size": 100, - "lambda": 0.95, - "clip_param": 0.2, - "num_sgd_iter": 20, - "lr": 0.0001, - "sgd_minibatch_size": 32, - "num_gpus": 1, - }, - }, -}) diff --git a/python/ray/rllib/examples/custom_fast_model.py b/python/ray/rllib/examples/custom_fast_model.py index 86201c87da7a..dce01e9e7754 100644 --- a/python/ray/rllib/examples/custom_fast_model.py +++ b/python/ray/rllib/examples/custom_fast_model.py @@ -11,11 +11,13 @@ from gym.spaces import Discrete, Box import gym import numpy as np -import tensorflow as tf import ray from ray.rllib.models import Model, ModelCatalog from ray.tune import run_experiments, sample_from +from ray.rllib.utils import try_import_tf + +tf = try_import_tf() class FastModel(Model): diff --git a/python/ray/rllib/examples/custom_loss.py b/python/ray/rllib/examples/custom_loss.py index 1f04f0fb5a6e..8905b48952da 100644 --- a/python/ray/rllib/examples/custom_loss.py +++ b/python/ray/rllib/examples/custom_loss.py @@ -15,7 +15,6 @@ import argparse import os -import tensorflow as tf import ray from ray import tune @@ -23,6 +22,9 @@ ModelCatalog) from ray.rllib.models.model import restore_original_dimensions from ray.rllib.offline import JsonReader +from ray.rllib.utils import try_import_tf + +tf = try_import_tf() parser = argparse.ArgumentParser() parser.add_argument("--iters", type=int, default=200) diff --git a/python/ray/rllib/examples/export/cartpole_dqn_export.py b/python/ray/rllib/examples/export/cartpole_dqn_export.py index 6bfcae060d13..47a5e3b41ea7 100644 --- a/python/ray/rllib/examples/export/cartpole_dqn_export.py +++ b/python/ray/rllib/examples/export/cartpole_dqn_export.py @@ -6,9 +6,11 @@ import os import ray -import tensorflow as tf from ray.rllib.agents.registry import get_agent_class +from ray.rllib.utils import try_import_tf + +tf = try_import_tf() ray.init(num_cpus=10) diff --git a/python/ray/rllib/examples/hierarchical_training.py b/python/ray/rllib/examples/hierarchical_training.py index c6d2db96837f..2fe61953dc96 100644 --- a/python/ray/rllib/examples/hierarchical_training.py +++ 
b/python/ray/rllib/examples/hierarchical_training.py @@ -209,7 +209,7 @@ def policy_mapping_fn(agent_id): "log_level": "INFO", "entropy_coeff": 0.01, "multiagent": { - "policy_graphs": { + "policies": { "high_level_policy": (None, maze.observation_space, Discrete(4), { "gamma": 0.9 diff --git a/python/ray/rllib/examples/multiagent_cartpole.py b/python/ray/rllib/examples/multiagent_cartpole.py index d7485e27a0c6..275c54390f97 100644 --- a/python/ray/rllib/examples/multiagent_cartpole.py +++ b/python/ray/rllib/examples/multiagent_cartpole.py @@ -6,7 +6,7 @@ Control the number of agents and policies via --num-agents and --num-policies. This works with hundreds of agents and policies, but note that initializing -many TF policy graphs will take some time. +many TF policies will take some time. Also, TF evals might slow down with large numbers of policies. To debug TF execution, set the TF_TIMELINE_DIR environment variable. @@ -16,20 +16,21 @@ import gym import random -import tensorflow as tf -import tensorflow.contrib.slim as slim - import ray from ray import tune from ray.rllib.models import Model, ModelCatalog from ray.rllib.tests.test_multi_agent_env import MultiCartpole from ray.tune.registry import register_env +from ray.rllib.utils import try_import_tf + +tf = try_import_tf() parser = argparse.ArgumentParser() parser.add_argument("--num-agents", type=int, default=4) parser.add_argument("--num-policies", type=int, default=2) parser.add_argument("--num-iters", type=int, default=20) +parser.add_argument("--simple", action="store_true") class CustomModel1(Model): @@ -43,12 +44,12 @@ def _build_layers_v2(self, input_dict, num_outputs, options): tf.VariableScope(tf.AUTO_REUSE, "shared"), reuse=tf.AUTO_REUSE, auxiliary_name_scope=False): - last_layer = slim.fully_connected( - input_dict["obs"], 64, activation_fn=tf.nn.relu, scope="fc1") - last_layer = slim.fully_connected( - last_layer, 64, activation_fn=tf.nn.relu, scope="fc2") - output = slim.fully_connected( - last_layer, num_outputs, activation_fn=None, scope="fc_out") + last_layer = tf.layers.dense( + input_dict["obs"], 64, activation=tf.nn.relu, name="fc1") + last_layer = tf.layers.dense( + last_layer, 64, activation=tf.nn.relu, name="fc2") + output = tf.layers.dense( + last_layer, num_outputs, activation=None, name="fc_out") return output, last_layer @@ -59,12 +60,12 @@ def _build_layers_v2(self, input_dict, num_outputs, options): tf.VariableScope(tf.AUTO_REUSE, "shared"), reuse=tf.AUTO_REUSE, auxiliary_name_scope=False): - last_layer = slim.fully_connected( - input_dict["obs"], 64, activation_fn=tf.nn.relu, scope="fc1") - last_layer = slim.fully_connected( - last_layer, 64, activation_fn=tf.nn.relu, scope="fc2") - output = slim.fully_connected( - last_layer, num_outputs, activation_fn=None, scope="fc_out") + last_layer = tf.layers.dense( + input_dict["obs"], 64, activation=tf.nn.relu, name="fc1") + last_layer = tf.layers.dense( + last_layer, 64, activation=tf.nn.relu, name="fc2") + output = tf.layers.dense( + last_layer, num_outputs, activation=None, name="fc_out") return output, last_layer @@ -90,12 +91,12 @@ def gen_policy(i): } return (None, obs_space, act_space, config) - # Setup PPO with an ensemble of `num_policies` different policy graphs - policy_graphs = { + # Setup PPO with an ensemble of `num_policies` different policies + policies = { "policy_{}".format(i): gen_policy(i) for i in range(args.num_policies) } - policy_ids = list(policy_graphs.keys()) + policy_ids = list(policies.keys()) tune.run( "PPO", @@ -103,9 +104,10 @@ def 
gen_policy(i): config={ "env": "multi_cartpole", "log_level": "DEBUG", + "simple_optimizer": args.simple, "num_sgd_iter": 10, "multiagent": { - "policy_graphs": policy_graphs, + "policies": policies, "policy_mapping_fn": tune.function( lambda agent_id: random.choice(policy_ids)), }, diff --git a/python/ray/rllib/examples/multiagent_custom_policy.py b/python/ray/rllib/examples/multiagent_custom_policy.py index 855051d52ef4..d34d678098b6 100644 --- a/python/ray/rllib/examples/multiagent_custom_policy.py +++ b/python/ray/rllib/examples/multiagent_custom_policy.py @@ -22,7 +22,7 @@ import ray from ray import tune -from ray.rllib.evaluation import PolicyGraph +from ray.rllib.policy import Policy from ray.rllib.tests.test_multi_agent_env import MultiCartpole from ray.tune.registry import register_env @@ -30,7 +30,7 @@ parser.add_argument("--num-iters", type=int, default=20) -class RandomPolicy(PolicyGraph): +class RandomPolicy(Policy): """Hand-coded policy that returns random actions.""" def compute_actions(self, @@ -65,7 +65,7 @@ def learn_on_batch(self, samples): config={ "env": "multi_cartpole", "multiagent": { - "policy_graphs": { + "policies": { "pg_policy": (None, obs_space, act_space, {}), "random": (RandomPolicy, obs_space, act_space, {}), }, diff --git a/python/ray/rllib/examples/multiagent_two_trainers.py b/python/ray/rllib/examples/multiagent_two_trainers.py index 2c18f2bf4b96..68c0e742e857 100644 --- a/python/ray/rllib/examples/multiagent_two_trainers.py +++ b/python/ray/rllib/examples/multiagent_two_trainers.py @@ -16,9 +16,9 @@ import ray from ray.rllib.agents.dqn.dqn import DQNTrainer -from ray.rllib.agents.dqn.dqn_policy_graph import DQNPolicyGraph +from ray.rllib.agents.dqn.dqn_policy import DQNTFPolicy from ray.rllib.agents.ppo.ppo import PPOTrainer -from ray.rllib.agents.ppo.ppo_policy_graph import PPOPolicyGraph +from ray.rllib.agents.ppo.ppo_policy import PPOTFPolicy from ray.rllib.tests.test_multi_agent_env import MultiCartpole from ray.tune.logger import pretty_print from ray.tune.registry import register_env @@ -36,11 +36,11 @@ obs_space = single_env.observation_space act_space = single_env.action_space - # You can also have multiple policy graphs per trainer, but here we just + # You can also have multiple policies per trainer, but here we just # show one each for PPO and DQN. 
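The common thread in these example updates is the rename of the multiagent config key from "policy_graphs" to "policies" (with PolicyGraph classes becoming Policy classes). A minimal sketch of the new-style config follows; the spaces and the even/odd mapping function are placeholders, not taken from a real environment.

    import gym
    import numpy as np
    from ray import tune

    # Placeholder spaces standing in for a real multi-agent env's spaces.
    obs_space = gym.spaces.Box(-1.0, 1.0, shape=(4, ), dtype=np.float32)
    act_space = gym.spaces.Discrete(2)

    config = {
        "multiagent": {
            # Each entry: (policy class, or None for the trainer default,
            #              observation space, action space, per-policy config)
            "policies": {
                "policy_0": (None, obs_space, act_space, {}),
                "policy_1": (None, obs_space, act_space, {}),
            },
            "policy_mapping_fn": tune.function(
                lambda agent_id: "policy_0" if agent_id % 2 == 0 else "policy_1"),
        },
    }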
- policy_graphs = { - "ppo_policy": (PPOPolicyGraph, obs_space, act_space, {}), - "dqn_policy": (DQNPolicyGraph, obs_space, act_space, {}), + policies = { + "ppo_policy": (PPOTFPolicy, obs_space, act_space, {}), + "dqn_policy": (DQNTFPolicy, obs_space, act_space, {}), } def policy_mapping_fn(agent_id): @@ -53,7 +53,7 @@ def policy_mapping_fn(agent_id): env="multi_cartpole", config={ "multiagent": { - "policy_graphs": policy_graphs, + "policies": policies, "policy_mapping_fn": policy_mapping_fn, "policies_to_train": ["ppo_policy"], }, @@ -66,7 +66,7 @@ def policy_mapping_fn(agent_id): env="multi_cartpole", config={ "multiagent": { - "policy_graphs": policy_graphs, + "policies": policies, "policy_mapping_fn": policy_mapping_fn, "policies_to_train": ["dqn_policy"], }, diff --git a/python/ray/rllib/examples/parametric_action_cartpole.py b/python/ray/rllib/examples/parametric_action_cartpole.py index 3d57c268cae3..e16e1ab75870 100644 --- a/python/ray/rllib/examples/parametric_action_cartpole.py +++ b/python/ray/rllib/examples/parametric_action_cartpole.py @@ -23,14 +23,15 @@ import numpy as np import gym from gym.spaces import Box, Discrete, Dict -import tensorflow as tf -import tensorflow.contrib.slim as slim import ray from ray import tune from ray.rllib.models import Model, ModelCatalog from ray.rllib.models.misc import normc_initializer from ray.tune.registry import register_env +from ray.rllib.utils import try_import_tf + +tf = try_import_tf() parser = argparse.ArgumentParser() parser.add_argument("--stop", type=int, default=200) @@ -134,18 +135,18 @@ def _build_layers_v2(self, input_dict, num_outputs, options): hiddens = [256, 256] for i, size in enumerate(hiddens): label = "fc{}".format(i) - last_layer = slim.fully_connected( + last_layer = tf.layers.dense( last_layer, size, - weights_initializer=normc_initializer(1.0), - activation_fn=tf.nn.tanh, - scope=label) - output = slim.fully_connected( + kernel_initializer=normc_initializer(1.0), + activation=tf.nn.tanh, + name=label) + output = tf.layers.dense( last_layer, action_embed_size, - weights_initializer=normc_initializer(0.01), - activation_fn=None, - scope="fc_out") + kernel_initializer=normc_initializer(0.01), + activation=None, + name="fc_out") # Expand the model output to [BATCH, 1, EMBED_SIZE]. Note that the # avail actions tensor is of shape [BATCH, MAX_ACTIONS, EMBED_SIZE]. diff --git a/python/ray/rllib/examples/policy_evaluator_custom_workflow.py b/python/ray/rllib/examples/policy_evaluator_custom_workflow.py index b07787129246..a8d80da994d2 100644 --- a/python/ray/rllib/examples/policy_evaluator_custom_workflow.py +++ b/python/ray/rllib/examples/policy_evaluator_custom_workflow.py @@ -1,7 +1,7 @@ """Example of using policy evaluator classes directly to implement training. Instead of using the built-in Trainer classes provided by RLlib, here we define -a custom PolicyGraph class and manually coordinate distributed sample +a custom Policy class and manually coordinate distributed sample collection and policy optimization. """ @@ -14,7 +14,8 @@ import ray from ray import tune -from ray.rllib.evaluation import PolicyGraph, PolicyEvaluator, SampleBatch +from ray.rllib.policy import Policy +from ray.rllib.evaluation import PolicyEvaluator, SampleBatch from ray.rllib.evaluation.metrics import collect_metrics parser = argparse.ArgumentParser() @@ -23,15 +24,15 @@ parser.add_argument("--num-workers", type=int, default=2) -class CustomPolicy(PolicyGraph): - """Example of a custom policy graph written from scratch. 
+class CustomPolicy(Policy): + """Example of a custom policy written from scratch. - You might find it more convenient to extend TF/TorchPolicyGraph instead + You might find it more convenient to extend TF/TorchPolicy instead for a real policy. """ def __init__(self, observation_space, action_space, config): - PolicyGraph.__init__(self, observation_space, action_space, config) + Policy.__init__(self, observation_space, action_space, config) # example parameter self.w = 1.0 diff --git a/python/ray/rllib/evaluation/keras_policy_graph.py b/python/ray/rllib/keras_policy.py similarity index 83% rename from python/ray/rllib/evaluation/keras_policy_graph.py rename to python/ray/rllib/keras_policy.py index 88d8e0a9be32..3008e133c1c6 100644 --- a/python/ray/rllib/evaluation/keras_policy_graph.py +++ b/python/ray/rllib/keras_policy.py @@ -4,19 +4,19 @@ import numpy as np -from ray.rllib.evaluation.policy_graph import PolicyGraph +from ray.rllib.policy.policy import Policy def _sample(probs): return [np.random.choice(len(pr), p=pr) for pr in probs] -class KerasPolicyGraph(PolicyGraph): - """Initialize the Keras Policy Graph. +class KerasPolicy(Policy): + """Initialize the Keras Policy. - This is a Policy Graph used for models with actor and critics. + This is a Policy used for models with actor and critics. Note: This class is built for specific usage of Actor-Critic models, - and is less general compared to TFPolicyGraph and TorchPolicyGraphs. + and is less general compared to TFPolicy and TorchPolicies. Args: observation_space (gym.Space): Observation space of the policy. @@ -32,7 +32,7 @@ def __init__(self, config, actor=None, critic=None): - PolicyGraph.__init__(self, observation_space, action_space, config) + Policy.__init__(self, observation_space, action_space, config) self.actor = actor self.critic = critic self.models = [self.actor, self.critic] diff --git a/python/ray/rllib/models/action_dist.py b/python/ray/rllib/models/action_dist.py index 026a6c493e5c..9cf58b9dd317 100644 --- a/python/ray/rllib/models/action_dist.py +++ b/python/ray/rllib/models/action_dist.py @@ -4,13 +4,22 @@ from collections import namedtuple import distutils.version -import tensorflow as tf import numpy as np from ray.rllib.utils.annotations import override, DeveloperAPI - -use_tf150_api = (distutils.version.LooseVersion(tf.VERSION) >= - distutils.version.LooseVersion("1.5.0")) +from ray.rllib.utils import try_import_tf + +tf = try_import_tf() + +if tf: + if hasattr(tf, "__version__"): + version = tf.__version__ + else: + version = tf.VERSION + use_tf150_api = (distutils.version.LooseVersion(version) >= + distutils.version.LooseVersion("1.5.0")) +else: + use_tf150_api = False @DeveloperAPI diff --git a/python/ray/rllib/models/catalog.py b/python/ray/rllib/models/catalog.py index ce91742c3f5d..d237474480e5 100644 --- a/python/ray/rllib/models/catalog.py +++ b/python/ray/rllib/models/catalog.py @@ -5,7 +5,6 @@ import gym import logging import numpy as np -import tensorflow as tf from functools import partial from ray.tune.registry import RLLIB_MODEL, RLLIB_PREPROCESSOR, \ @@ -22,6 +21,9 @@ from ray.rllib.models.visionnet import VisionNetwork from ray.rllib.models.lstm import LSTM from ray.rllib.utils.annotations import DeveloperAPI, PublicAPI +from ray.rllib.utils import try_import_tf + +tf = try_import_tf() logger = logging.getLogger(__name__) diff --git a/python/ray/rllib/models/fcnet.py b/python/ray/rllib/models/fcnet.py index 19745b9e7a3c..c3bacbd46a7d 100644 --- a/python/ray/rllib/models/fcnet.py +++ 
b/python/ray/rllib/models/fcnet.py @@ -2,12 +2,12 @@ from __future__ import division from __future__ import print_function -import tensorflow as tf -import tensorflow.contrib.slim as slim - from ray.rllib.models.model import Model from ray.rllib.models.misc import normc_initializer, get_activation_fn from ray.rllib.utils.annotations import override +from ray.rllib.utils import try_import_tf + +tf = try_import_tf() class FullyConnectedNetwork(Model): @@ -29,18 +29,18 @@ def _build_layers(self, inputs, num_outputs, options): last_layer = inputs for size in hiddens: label = "fc{}".format(i) - last_layer = slim.fully_connected( + last_layer = tf.layers.dense( last_layer, size, - weights_initializer=normc_initializer(1.0), - activation_fn=activation, - scope=label) + kernel_initializer=normc_initializer(1.0), + activation=activation, + name=label) i += 1 label = "fc_out" - output = slim.fully_connected( + output = tf.layers.dense( last_layer, num_outputs, - weights_initializer=normc_initializer(0.01), - activation_fn=None, - scope=label) + kernel_initializer=normc_initializer(0.01), + activation=None, + name=label) return output, last_layer diff --git a/python/ray/rllib/models/lstm.py b/python/ray/rllib/models/lstm.py index 18f141d095f9..62b854a86ed9 100644 --- a/python/ray/rllib/models/lstm.py +++ b/python/ray/rllib/models/lstm.py @@ -18,12 +18,13 @@ """ import numpy as np -import tensorflow as tf -import tensorflow.contrib.rnn as rnn from ray.rllib.models.misc import linear, normc_initializer from ray.rllib.models.model import Model from ray.rllib.utils.annotations import override, DeveloperAPI, PublicAPI +from ray.rllib.utils import try_import_tf + +tf = try_import_tf() class LSTM(Model): @@ -73,7 +74,7 @@ def _build_layers_v2(self, input_dict, num_outputs, options): self.state_in = [c_in, h_in] # Setup LSTM outputs - state_in = rnn.LSTMStateTuple(c_in, h_in) + state_in = tf.nn.rnn_cell.LSTMStateTuple(c_in, h_in) lstm_out, lstm_state = tf.nn.dynamic_rnn( lstm, last_layer, diff --git a/python/ray/rllib/models/misc.py b/python/ray/rllib/models/misc.py index aad399c3b222..73ee1d87c6fd 100644 --- a/python/ray/rllib/models/misc.py +++ b/python/ray/rllib/models/misc.py @@ -2,8 +2,10 @@ from __future__ import division from __future__ import print_function -import tensorflow as tf import numpy as np +from ray.rllib.utils import try_import_tf + +tf = try_import_tf() def normc_initializer(std=1.0): @@ -25,8 +27,11 @@ def conv2d(x, filter_size=(3, 3), stride=(1, 1), pad="SAME", - dtype=tf.float32, + dtype=None, collections=None): + if dtype is None: + dtype = tf.float32 + with tf.variable_scope(name): stride_shape = [1, stride[0], stride[1], 1] filter_shape = [ diff --git a/python/ray/rllib/models/model.py b/python/ray/rllib/models/model.py index b5664057d9a8..901ffa8024bf 100644 --- a/python/ray/rllib/models/model.py +++ b/python/ray/rllib/models/model.py @@ -5,11 +5,13 @@ from collections import OrderedDict import gym -import tensorflow as tf from ray.rllib.models.misc import linear, normc_initializer from ray.rllib.models.preprocessors import get_preprocessor from ray.rllib.utils.annotations import PublicAPI, DeveloperAPI +from ray.rllib.utils import try_import_tf + +tf = try_import_tf() @PublicAPI @@ -159,7 +161,7 @@ def custom_loss(self, policy_loss, loss_inputs): You can find an runnable example in examples/custom_loss.py. Arguments: - policy_loss (Tensor): scalar policy loss from the policy graph. + policy_loss (Tensor): scalar policy loss from the policy. 
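The layer rewrites in this and the surrounding model files all follow one mechanical mapping from tf.contrib.slim to tf.layers: weights_initializer becomes kernel_initializer, activation_fn becomes activation, and scope becomes name. A small sketch of the pattern, assuming TF 1.x (which this code targets); the input tensor, sizes, and initializer are arbitrary choices for illustration.

    import tensorflow as tf

    inputs = tf.placeholder(tf.float32, [None, 16])

    # Before: slim.fully_connected(inputs, 64, weights_initializer=init,
    #                              activation_fn=tf.nn.relu, scope="fc1")
    hidden = tf.layers.dense(
        inputs,
        64,
        kernel_initializer=tf.glorot_uniform_initializer(),
        activation=tf.nn.relu,
        name="fc1")

    # Before: slim.fully_connected(hidden, 2, activation_fn=None, scope="fc_out")
    logits = tf.layers.dense(hidden, 2, activation=None, name="fc_out")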
loss_inputs (dict): map of input placeholders for rollout data. Returns: diff --git a/python/ray/rllib/models/visionnet.py b/python/ray/rllib/models/visionnet.py index 432a3317c782..6ad30ddb90c4 100644 --- a/python/ray/rllib/models/visionnet.py +++ b/python/ray/rllib/models/visionnet.py @@ -2,12 +2,12 @@ from __future__ import division from __future__ import print_function -import tensorflow as tf -import tensorflow.contrib.slim as slim - from ray.rllib.models.model import Model from ray.rllib.models.misc import get_activation_fn, flatten from ray.rllib.utils.annotations import override +from ray.rllib.utils import try_import_tf + +tf = try_import_tf() class VisionNetwork(Model): @@ -24,28 +24,29 @@ def _build_layers_v2(self, input_dict, num_outputs, options): with tf.name_scope("vision_net"): for i, (out_size, kernel, stride) in enumerate(filters[:-1], 1): - inputs = slim.conv2d( + inputs = tf.layers.conv2d( inputs, out_size, kernel, stride, - activation_fn=activation, - scope="conv{}".format(i)) + activation=activation, + padding="same", + name="conv{}".format(i)) out_size, kernel, stride = filters[-1] - fc1 = slim.conv2d( + fc1 = tf.layers.conv2d( inputs, out_size, kernel, stride, - activation_fn=activation, - padding="VALID", - scope="fc1") - fc2 = slim.conv2d( + activation=activation, + padding="valid", + name="fc1") + fc2 = tf.layers.conv2d( fc1, num_outputs, [1, 1], - activation_fn=None, - normalizer_fn=None, - scope="fc2") + activation=None, + padding="same", + name="fc2") return flatten(fc2), flatten(fc1) diff --git a/python/ray/rllib/offline/input_reader.py b/python/ray/rllib/offline/input_reader.py index bb4fe91161a2..053c279343a8 100644 --- a/python/ray/rllib/offline/input_reader.py +++ b/python/ray/rllib/offline/input_reader.py @@ -4,11 +4,13 @@ import logging import numpy as np -import tensorflow as tf import threading -from ray.rllib.evaluation.sample_batch import MultiAgentBatch +from ray.rllib.policy.sample_batch import MultiAgentBatch from ray.rllib.utils.annotations import PublicAPI +from ray.rllib.utils import try_import_tf + +tf = try_import_tf() logger = logging.getLogger(__name__) diff --git a/python/ray/rllib/offline/json_reader.py b/python/ray/rllib/offline/json_reader.py index e9568e75c7f4..55a002fb3ce6 100644 --- a/python/ray/rllib/offline/json_reader.py +++ b/python/ray/rllib/offline/json_reader.py @@ -17,7 +17,7 @@ from ray.rllib.offline.input_reader import InputReader from ray.rllib.offline.io_context import IOContext -from ray.rllib.evaluation.sample_batch import MultiAgentBatch, SampleBatch, \ +from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch, \ DEFAULT_POLICY_ID from ray.rllib.utils.annotations import override, PublicAPI from ray.rllib.utils.compression import unpack_if_needed diff --git a/python/ray/rllib/offline/json_writer.py b/python/ray/rllib/offline/json_writer.py index 5613d1f67dc2..679b00158b9e 100644 --- a/python/ray/rllib/offline/json_writer.py +++ b/python/ray/rllib/offline/json_writer.py @@ -15,7 +15,7 @@ except ImportError: smart_open = None -from ray.rllib.evaluation.sample_batch import MultiAgentBatch +from ray.rllib.policy.sample_batch import MultiAgentBatch from ray.rllib.offline.io_context import IOContext from ray.rllib.offline.output_writer import OutputWriter from ray.rllib.utils.annotations import override, PublicAPI diff --git a/python/ray/rllib/offline/off_policy_estimator.py b/python/ray/rllib/offline/off_policy_estimator.py index d09fe6baf052..7534e667f0bf 100644 --- 
a/python/ray/rllib/offline/off_policy_estimator.py +++ b/python/ray/rllib/offline/off_policy_estimator.py @@ -5,7 +5,7 @@ from collections import namedtuple import logging -from ray.rllib.evaluation.sample_batch import MultiAgentBatch +from ray.rllib.policy.sample_batch import MultiAgentBatch from ray.rllib.utils.annotations import DeveloperAPI logger = logging.getLogger(__name__) @@ -23,7 +23,7 @@ def __init__(self, policy, gamma): """Creates an off-policy estimator. Arguments: - policy (PolicyGraph): Policy graph to evaluate. + policy (Policy): Policy to evaluate. gamma (float): Discount of the MDP. """ self.policy = policy @@ -71,7 +71,7 @@ def action_prob(self, batch): raise ValueError( "Off-policy estimation is not possible unless the policy " "returns action probabilities when computing actions (i.e., " - "the 'action_prob' key is output by the policy graph). You " + "the 'action_prob' key is output by the policy). You " "can set `input_evaluation: []` to resolve this.") return info["action_prob"] diff --git a/python/ray/rllib/optimizers/aso_multi_gpu_learner.py b/python/ray/rllib/optimizers/aso_multi_gpu_learner.py index a584be7e6c53..b5040e45584c 100644 --- a/python/ray/rllib/optimizers/aso_multi_gpu_learner.py +++ b/python/ray/rllib/optimizers/aso_multi_gpu_learner.py @@ -11,12 +11,15 @@ from six.moves import queue from ray.rllib.evaluation.metrics import get_learner_stats -from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID +from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID from ray.rllib.optimizers.aso_learner import LearnerThread from ray.rllib.optimizers.aso_minibatch_buffer import MinibatchBuffer from ray.rllib.optimizers.multi_gpu_impl import LocalSyncParallelOptimizer from ray.rllib.utils.annotations import override from ray.rllib.utils.timer import TimerStat +from ray.rllib.utils import try_import_tf + +tf = try_import_tf() logger = logging.getLogger(__name__) @@ -38,9 +41,6 @@ def __init__(self, learner_queue_size=16, num_data_load_threads=16, _fake_gpus=False): - # Multi-GPU requires TensorFlow to function. - import tensorflow as tf - LearnerThread.__init__(self, local_evaluator, minibatch_buffer_size, num_sgd_iter, learner_queue_size) self.lr = lr diff --git a/python/ray/rllib/optimizers/async_replay_optimizer.py b/python/ray/rllib/optimizers/async_replay_optimizer.py index b040c8e8a99f..d66f942ae532 100644 --- a/python/ray/rllib/optimizers/async_replay_optimizer.py +++ b/python/ray/rllib/optimizers/async_replay_optimizer.py @@ -17,7 +17,7 @@ import ray from ray.rllib.evaluation.metrics import get_learner_stats -from ray.rllib.evaluation.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \ +from ray.rllib.policy.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \ MultiAgentBatch from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer from ray.rllib.optimizers.replay_buffer import PrioritizedReplayBuffer diff --git a/python/ray/rllib/optimizers/multi_gpu_impl.py b/python/ray/rllib/optimizers/multi_gpu_impl.py index 7c3feb165e5b..aad301b29eee 100644 --- a/python/ray/rllib/optimizers/multi_gpu_impl.py +++ b/python/ray/rllib/optimizers/multi_gpu_impl.py @@ -4,9 +4,11 @@ from collections import namedtuple import logging -import tensorflow as tf from ray.rllib.utils.debug import log_once, summarize +from ray.rllib.utils import try_import_tf + +tf = try_import_tf() # Variable scope in which created variables will be placed under TOWER_SCOPE_NAME = "tower" @@ -46,7 +48,7 @@ class LocalSyncParallelOptimizer(object): processed. 
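Another pattern repeated across these files is replacing a module-level `import tensorflow as tf` with ray.rllib.utils.try_import_tf, so importing RLlib no longer hard-requires TensorFlow; judging from the `if tf:` guard added to action_dist.py above, the helper returns a falsy value when TensorFlow is missing. A sketch of how a caller can use it, with the guard pushed into the code path that actually needs TF:

    from ray.rllib.utils import try_import_tf

    tf = try_import_tf()  # the tensorflow module, or a falsy value if TF is not installed

    def make_constant(value):
        # Fail only in code paths that genuinely need TensorFlow.
        if not tf:
            raise ImportError("TensorFlow is required for this code path.")
        return tf.constant(value)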
If this is larger than the total data size, it will be clipped. build_graph: Function that takes the specified inputs and returns a - TF Policy Graph instance. + TF Policy instance. """ def __init__(self, @@ -253,7 +255,7 @@ def optimize(self, sess, batch_index): fetches = {"train": self._train_op} for tower in self._towers: - fetches.update(tower.loss_graph.extra_compute_grad_fetches()) + fetches.update(tower.loss_graph._get_grad_and_stats_fetches()) return sess.run(fetches, feed_dict=feed_dict) diff --git a/python/ray/rllib/optimizers/multi_gpu_optimizer.py b/python/ray/rllib/optimizers/multi_gpu_optimizer.py index 23ee1833b9f0..a25553c40111 100644 --- a/python/ray/rllib/optimizers/multi_gpu_optimizer.py +++ b/python/ray/rllib/optimizers/multi_gpu_optimizer.py @@ -6,19 +6,21 @@ import math import numpy as np from collections import defaultdict -import tensorflow as tf import ray from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY -from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph +from ray.rllib.policy.tf_policy import TFPolicy from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer from ray.rllib.optimizers.multi_gpu_impl import LocalSyncParallelOptimizer from ray.rllib.optimizers.rollout import collect_samples, \ collect_samples_straggler_mitigation from ray.rllib.utils.annotations import override from ray.rllib.utils.timer import TimerStat -from ray.rllib.evaluation.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \ +from ray.rllib.policy.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \ MultiAgentBatch +from ray.rllib.utils import try_import_tf + +tf = try_import_tf() logger = logging.getLogger(__name__) @@ -32,9 +34,9 @@ class LocalMultiGPUOptimizer(PolicyOptimizer): details, see `multi_gpu_impl.LocalSyncParallelOptimizer`. This optimizer is Tensorflow-specific and require the underlying - PolicyGraph to be a TFPolicyGraph instance that support `.copy()`. + Policy to be a TFPolicy instance that support `.copy()`. - Note that all replicas of the TFPolicyGraph will merge their + Note that all replicas of the TFPolicy will merge their extra_compute_grad and apply_grad feed_dicts and fetches. This may result in unexpected behavior. """ @@ -81,7 +83,7 @@ def __init__(self, self.local_evaluator.foreach_trainable_policy(lambda p, i: (i, p))) logger.debug("Policies to train: {}".format(self.policies)) for policy_id, policy in self.policies.items(): - if not isinstance(policy, TFPolicyGraph): + if not isinstance(policy, TFPolicy): raise ValueError( "Only TF policies are supported with multi-GPU. 
Try using " "the simple optimizer instead.") @@ -220,6 +222,6 @@ def stats(self): def _averaged(kv): out = {} for k, v in kv.items(): - if v[0] is not None: + if v[0] is not None and not isinstance(v[0], dict): out[k] = np.mean(v) return out diff --git a/python/ray/rllib/optimizers/rollout.py b/python/ray/rllib/optimizers/rollout.py index 063c2ff8999d..fa1c03f6081e 100644 --- a/python/ray/rllib/optimizers/rollout.py +++ b/python/ray/rllib/optimizers/rollout.py @@ -5,7 +5,7 @@ import logging import ray -from ray.rllib.evaluation.sample_batch import SampleBatch +from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.utils.memory import ray_get_and_free logger = logging.getLogger(__name__) diff --git a/python/ray/rllib/optimizers/sync_batch_replay_optimizer.py b/python/ray/rllib/optimizers/sync_batch_replay_optimizer.py index 0a334e84ef79..e13d71c6e4cd 100644 --- a/python/ray/rllib/optimizers/sync_batch_replay_optimizer.py +++ b/python/ray/rllib/optimizers/sync_batch_replay_optimizer.py @@ -7,7 +7,7 @@ import ray from ray.rllib.evaluation.metrics import get_learner_stats from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer -from ray.rllib.evaluation.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \ +from ray.rllib.policy.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \ MultiAgentBatch from ray.rllib.utils.annotations import override from ray.rllib.utils.timer import TimerStat diff --git a/python/ray/rllib/optimizers/sync_replay_optimizer.py b/python/ray/rllib/optimizers/sync_replay_optimizer.py index 2e765f2d8641..27858f3527c1 100644 --- a/python/ray/rllib/optimizers/sync_replay_optimizer.py +++ b/python/ray/rllib/optimizers/sync_replay_optimizer.py @@ -11,7 +11,7 @@ PrioritizedReplayBuffer from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer from ray.rllib.evaluation.metrics import get_learner_stats -from ray.rllib.evaluation.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \ +from ray.rllib.policy.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \ MultiAgentBatch from ray.rllib.utils.annotations import override from ray.rllib.utils.compression import pack_if_needed diff --git a/python/ray/rllib/optimizers/sync_samples_optimizer.py b/python/ray/rllib/optimizers/sync_samples_optimizer.py index a08f0345eb2b..a49b290d3e2c 100644 --- a/python/ray/rllib/optimizers/sync_samples_optimizer.py +++ b/python/ray/rllib/optimizers/sync_samples_optimizer.py @@ -6,7 +6,7 @@ import logging from ray.rllib.evaluation.metrics import get_learner_stats from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer -from ray.rllib.evaluation.sample_batch import SampleBatch +from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.utils.annotations import override from ray.rllib.utils.filter import RunningStat from ray.rllib.utils.timer import TimerStat @@ -69,7 +69,7 @@ def step(self): self.num_steps_sampled += samples.count self.num_steps_trained += samples.count - return fetches + return self.learner_stats @override(PolicyOptimizer) def stats(self): diff --git a/python/ray/rllib/policy/__init__.py b/python/ray/rllib/policy/__init__.py new file mode 100644 index 000000000000..0f172dcd566d --- /dev/null +++ b/python/ray/rllib/policy/__init__.py @@ -0,0 +1,17 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from ray.rllib.policy.policy import Policy +from ray.rllib.policy.torch_policy import TorchPolicy +from ray.rllib.policy.tf_policy import TFPolicy +from 
ray.rllib.policy.torch_policy_template import build_torch_policy +from ray.rllib.policy.tf_policy_template import build_tf_policy + +__all__ = [ + "Policy", + "TFPolicy", + "TorchPolicy", + "build_tf_policy", + "build_torch_policy", +] diff --git a/python/ray/rllib/policy/dynamic_tf_policy.py b/python/ray/rllib/policy/dynamic_tf_policy.py new file mode 100644 index 000000000000..691fc1186272 --- /dev/null +++ b/python/ray/rllib/policy/dynamic_tf_policy.py @@ -0,0 +1,275 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from collections import OrderedDict +import logging +import numpy as np + +from ray.rllib.policy.policy import Policy +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.policy.tf_policy import TFPolicy +from ray.rllib.models.catalog import ModelCatalog +from ray.rllib.utils.annotations import override +from ray.rllib.utils import try_import_tf +from ray.rllib.utils.debug import log_once, summarize +from ray.rllib.utils.tracking_dict import UsageTrackingDict + +tf = try_import_tf() + +logger = logging.getLogger(__name__) + + +class DynamicTFPolicy(TFPolicy): + """A TFPolicy that auto-defines placeholders dynamically at runtime. + + Initialization of this class occurs in two phases. + * Phase 1: the model is created and model variables are initialized. + * Phase 2: a fake batch of data is created, sent to the trajectory + postprocessor, and then used to create placeholders for the loss + function. The loss and stats functions are initialized with these + placeholders. + """ + + def __init__(self, + obs_space, + action_space, + config, + loss_fn, + stats_fn=None, + grad_stats_fn=None, + before_loss_init=None, + make_action_sampler=None, + existing_inputs=None, + get_batch_divisibility_req=None): + """Initialize a dynamic TF policy. + + Arguments: + observation_space (gym.Space): Observation space of the policy. + action_space (gym.Space): Action space of the policy. + config (dict): Policy-specific configuration data. + loss_fn (func): function that returns a loss tensor the policy + graph, and dict of experience tensor placeholders + stats_fn (func): optional function that returns a dict of + TF fetches given the policy and batch input tensors + grad_stats_fn (func): optional function that returns a dict of + TF fetches given the policy and loss gradient tensors + before_loss_init (func): optional function to run prior to loss + init that takes the same arguments as __init__ + make_action_sampler (func): optional function that returns a + tuple of action and action prob tensors. 
The function takes + (policy, input_dict, obs_space, action_space, config) as its + arguments + existing_inputs (OrderedDict): when copying a policy, this + specifies an existing dict of placeholders to use instead of + defining new ones + get_batch_divisibility_req (func): optional function that returns + the divisibility requirement for sample batches + """ + self.config = config + self._loss_fn = loss_fn + self._stats_fn = stats_fn + self._grad_stats_fn = grad_stats_fn + + # Setup standard placeholders + if existing_inputs is not None: + obs = existing_inputs[SampleBatch.CUR_OBS] + prev_actions = existing_inputs[SampleBatch.PREV_ACTIONS] + prev_rewards = existing_inputs[SampleBatch.PREV_REWARDS] + else: + obs = tf.placeholder( + tf.float32, + shape=[None] + list(obs_space.shape), + name="observation") + prev_actions = ModelCatalog.get_action_placeholder(action_space) + prev_rewards = tf.placeholder( + tf.float32, [None], name="prev_reward") + + input_dict = { + "obs": obs, + "prev_actions": prev_actions, + "prev_rewards": prev_rewards, + "is_training": self._get_is_training_placeholder(), + } + + # Create the model network and action outputs + if make_action_sampler: + assert not existing_inputs, \ + "Cloning not supported with custom action sampler" + self.model = None + self.dist_class = None + self.action_dist = None + action_sampler, action_prob = make_action_sampler( + self, input_dict, obs_space, action_space, config) + else: + self.dist_class, logit_dim = ModelCatalog.get_action_dist( + action_space, self.config["model"]) + if existing_inputs: + existing_state_in = [ + v for k, v in existing_inputs.items() + if k.startswith("state_in_") + ] + if existing_state_in: + existing_seq_lens = existing_inputs["seq_lens"] + else: + existing_seq_lens = None + else: + existing_state_in = [] + existing_seq_lens = None + self.model = ModelCatalog.get_model( + input_dict, + obs_space, + action_space, + logit_dim, + self.config["model"], + state_in=existing_state_in, + seq_lens=existing_seq_lens) + self.action_dist = self.dist_class(self.model.outputs) + action_sampler = self.action_dist.sample() + action_prob = self.action_dist.sampled_action_prob() + + # Phase 1 init + sess = tf.get_default_session() + if get_batch_divisibility_req: + batch_divisibility_req = get_batch_divisibility_req(self) + else: + batch_divisibility_req = 1 + TFPolicy.__init__( + self, + obs_space, + action_space, + sess, + obs_input=obs, + action_sampler=action_sampler, + action_prob=action_prob, + loss=None, # dynamically initialized on run + loss_inputs=[], + model=self.model, + state_inputs=self.model and self.model.state_in, + state_outputs=self.model and self.model.state_out, + prev_action_input=prev_actions, + prev_reward_input=prev_rewards, + seq_lens=self.model and self.model.seq_lens, + max_seq_len=config["model"]["max_seq_len"], + batch_divisibility_req=batch_divisibility_req) + + # Phase 2 init + before_loss_init(self, obs_space, action_space, config) + if not existing_inputs: + self._initialize_loss() + + @override(TFPolicy) + def copy(self, existing_inputs): + """Creates a copy of self using existing input placeholders.""" + + # Note that there might be RNN state inputs at the end of the list + if self._state_inputs: + num_state_inputs = len(self._state_inputs) + 1 + else: + num_state_inputs = 0 + if len(self._loss_inputs) + num_state_inputs != len(existing_inputs): + raise ValueError("Tensor list mismatch", self._loss_inputs, + self._state_inputs, existing_inputs) + for i, (k, v) in 
enumerate(self._loss_inputs): + if v.shape.as_list() != existing_inputs[i].shape.as_list(): + raise ValueError("Tensor shape mismatch", i, k, v.shape, + existing_inputs[i].shape) + # By convention, the loss inputs are followed by state inputs and then + # the seq len tensor + rnn_inputs = [] + for i in range(len(self._state_inputs)): + rnn_inputs.append(("state_in_{}".format(i), + existing_inputs[len(self._loss_inputs) + i])) + if rnn_inputs: + rnn_inputs.append(("seq_lens", existing_inputs[-1])) + input_dict = OrderedDict( + [(k, existing_inputs[i]) + for i, (k, _) in enumerate(self._loss_inputs)] + rnn_inputs) + instance = self.__class__( + self.observation_space, + self.action_space, + self.config, + existing_inputs=input_dict) + loss = instance._loss_fn(instance, input_dict) + if instance._stats_fn: + instance._stats_fetches.update( + instance._stats_fn(instance, input_dict)) + TFPolicy._initialize_loss( + instance, loss, [(k, existing_inputs[i]) + for i, (k, _) in enumerate(self._loss_inputs)]) + if instance._grad_stats_fn: + instance._stats_fetches.update( + instance._grad_stats_fn(instance, instance._grads)) + return instance + + @override(Policy) + def get_initial_state(self): + if self.model: + return self.model.state_init + else: + return [] + + def _initialize_loss(self): + def fake_array(tensor): + shape = tensor.shape.as_list() + shape[0] = 1 + return np.zeros(shape, dtype=tensor.dtype.as_numpy_dtype) + + dummy_batch = { + SampleBatch.PREV_ACTIONS: fake_array(self._prev_action_input), + SampleBatch.PREV_REWARDS: fake_array(self._prev_reward_input), + SampleBatch.CUR_OBS: fake_array(self._obs_input), + SampleBatch.NEXT_OBS: fake_array(self._obs_input), + SampleBatch.ACTIONS: fake_array(self._prev_action_input), + SampleBatch.REWARDS: np.array([0], dtype=np.float32), + SampleBatch.DONES: np.array([False], dtype=np.bool), + } + state_init = self.get_initial_state() + for i, h in enumerate(state_init): + dummy_batch["state_in_{}".format(i)] = np.expand_dims(h, 0) + dummy_batch["state_out_{}".format(i)] = np.expand_dims(h, 0) + if state_init: + dummy_batch["seq_lens"] = np.array([1], dtype=np.int32) + for k, v in self.extra_compute_action_fetches().items(): + dummy_batch[k] = fake_array(v) + + # postprocessing might depend on variable init, so run it first here + self._sess.run(tf.global_variables_initializer()) + postprocessed_batch = self.postprocess_trajectory( + SampleBatch(dummy_batch)) + + batch_tensors = UsageTrackingDict({ + SampleBatch.PREV_ACTIONS: self._prev_action_input, + SampleBatch.PREV_REWARDS: self._prev_reward_input, + SampleBatch.CUR_OBS: self._obs_input, + }) + loss_inputs = [ + (SampleBatch.PREV_ACTIONS, self._prev_action_input), + (SampleBatch.PREV_REWARDS, self._prev_reward_input), + (SampleBatch.CUR_OBS, self._obs_input), + ] + + for k, v in postprocessed_batch.items(): + if k in batch_tensors: + continue + elif v.dtype == np.object: + continue # can't handle arbitrary objects in TF + shape = (None, ) + v.shape[1:] + dtype = np.float32 if v.dtype == np.float64 else v.dtype + placeholder = tf.placeholder(dtype, shape=shape, name=k) + batch_tensors[k] = placeholder + + if log_once("loss_init"): + logger.info( + "Initializing loss function with dummy input:\n\n{}\n".format( + summarize(batch_tensors))) + + loss = self._loss_fn(self, batch_tensors) + if self._stats_fn: + self._stats_fetches.update(self._stats_fn(self, batch_tensors)) + for k in sorted(batch_tensors.accessed_keys): + loss_inputs.append((k, batch_tensors[k])) + TFPolicy._initialize_loss(self, 
loss, loss_inputs) + if self._grad_stats_fn: + self._stats_fetches.update(self._grad_stats_fn(self, self._grads)) + self._sess.run(tf.global_variables_initializer()) diff --git a/python/ray/rllib/policy/policy.py b/python/ray/rllib/policy/policy.py new file mode 100644 index 000000000000..6f456e608007 --- /dev/null +++ b/python/ray/rllib/policy/policy.py @@ -0,0 +1,291 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import gym + +from ray.rllib.utils.annotations import DeveloperAPI + +# By convention, metrics from optimizing the loss can be reported in the +# `grad_info` dict returned by learn_on_batch() / compute_grads() via this key. +LEARNER_STATS_KEY = "learner_stats" + + +@DeveloperAPI +class Policy(object): + """An agent policy and loss, i.e., a TFPolicy or other subclass. + + This object defines how to act in the environment, and also losses used to + improve the policy based on its experiences. Note that both policy and + loss are defined together for convenience, though the policy itself is + logically separate. + + All policies can directly extend Policy, however TensorFlow users may + find TFPolicy simpler to implement. TFPolicy also enables RLlib + to apply TensorFlow-specific optimizations such as fusing multiple policy + graphs and multi-GPU support. + + Attributes: + observation_space (gym.Space): Observation space of the policy. + action_space (gym.Space): Action space of the policy. + """ + + @DeveloperAPI + def __init__(self, observation_space, action_space, config): + """Initialize the graph. + + This is the standard constructor for policies. The policy + class you pass into PolicyEvaluator will be constructed with + these arguments. + + Args: + observation_space (gym.Space): Observation space of the policy. + action_space (gym.Space): Action space of the policy. + config (dict): Policy-specific configuration data. + """ + + self.observation_space = observation_space + self.action_space = action_space + + @DeveloperAPI + def compute_actions(self, + obs_batch, + state_batches, + prev_action_batch=None, + prev_reward_batch=None, + info_batch=None, + episodes=None, + **kwargs): + """Compute actions for the current policy. + + Arguments: + obs_batch (np.ndarray): batch of observations + state_batches (list): list of RNN state input batches, if any + prev_action_batch (np.ndarray): batch of previous action values + prev_reward_batch (np.ndarray): batch of previous rewards + info_batch (info): batch of info objects + episodes (list): MultiAgentEpisode for each obs in obs_batch. + This provides access to all of the internal episode state, + which may be useful for model-based or multiagent algorithms. + kwargs: forward compatibility placeholder + + Returns: + actions (np.ndarray): batch of output actions, with shape like + [BATCH_SIZE, ACTION_SHAPE]. + state_outs (list): list of RNN state output batches, if any, with + shape like [STATE_SIZE, BATCH_SIZE]. + info (dict): dictionary of extra feature batches, if any, with + shape like {"f1": [BATCH_SIZE, ...], "f2": [BATCH_SIZE, ...]}. + """ + raise NotImplementedError + + @DeveloperAPI + def compute_single_action(self, + obs, + state, + prev_action=None, + prev_reward=None, + info=None, + episode=None, + clip_actions=False, + **kwargs): + """Unbatched version of compute_actions. 
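To make the compute_actions contract concrete, the following is a minimal sketch of a Policy subclass that samples uniformly from a gym Discrete action space and has nothing to learn; the RandomDiscretePolicy name and its trivial weight handling are illustrative assumptions, not something this patch defines.

import numpy as np

from ray.rllib.policy.policy import Policy


class RandomDiscretePolicy(Policy):
    """Illustration only: ignores observations, returns random actions."""

    def compute_actions(self,
                        obs_batch,
                        state_batches,
                        prev_action_batch=None,
                        prev_reward_batch=None,
                        info_batch=None,
                        episodes=None,
                        **kwargs):
        # One action per observation; no RNN state, no extra fetches.
        actions = np.random.randint(self.action_space.n, size=len(obs_batch))
        return actions, [], {}

    def learn_on_batch(self, samples):
        return {}  # nothing to optimize

    def get_weights(self):
        return {}

    def set_weights(self, weights):
        pass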
+ + Arguments: + obs (obj): single observation + state_batches (list): list of RNN state inputs, if any + prev_action (obj): previous action value, if any + prev_reward (int): previous reward, if any + info (dict): info object, if any + episode (MultiAgentEpisode): this provides access to all of the + internal episode state, which may be useful for model-based or + multi-agent algorithms. + clip_actions (bool): should the action be clipped + kwargs: forward compatibility placeholder + + Returns: + actions (obj): single action + state_outs (list): list of RNN state outputs, if any + info (dict): dictionary of extra features, if any + """ + + prev_action_batch = None + prev_reward_batch = None + info_batch = None + episodes = None + if prev_action is not None: + prev_action_batch = [prev_action] + if prev_reward is not None: + prev_reward_batch = [prev_reward] + if info is not None: + info_batch = [info] + if episode is not None: + episodes = [episode] + [action], state_out, info = self.compute_actions( + [obs], [[s] for s in state], + prev_action_batch=prev_action_batch, + prev_reward_batch=prev_reward_batch, + info_batch=info_batch, + episodes=episodes) + if clip_actions: + action = clip_action(action, self.action_space) + return action, [s[0] for s in state_out], \ + {k: v[0] for k, v in info.items()} + + @DeveloperAPI + def postprocess_trajectory(self, + sample_batch, + other_agent_batches=None, + episode=None): + """Implements algorithm-specific trajectory postprocessing. + + This will be called on each trajectory fragment computed during policy + evaluation. Each fragment is guaranteed to be only from one episode. + + Arguments: + sample_batch (SampleBatch): batch of experiences for the policy, + which will contain at most one episode trajectory. + other_agent_batches (dict): In a multi-agent env, this contains a + mapping of agent ids to (policy, agent_batch) tuples + containing the policy and experiences of the other agent. + episode (MultiAgentEpisode): this provides access to all of the + internal episode state, which may be useful for model-based or + multi-agent algorithms. + + Returns: + SampleBatch: postprocessed sample batch. + """ + return sample_batch + + @DeveloperAPI + def learn_on_batch(self, samples): + """Fused compute gradients and apply gradients call. + + Either this or the combination of compute/apply grads must be + implemented by subclasses. + + Returns: + grad_info: dictionary of extra metadata from compute_gradients(). + + Examples: + >>> batch = ev.sample() + >>> ev.learn_on_batch(samples) + """ + + grads, grad_info = self.compute_gradients(samples) + self.apply_gradients(grads) + return grad_info + + @DeveloperAPI + def compute_gradients(self, postprocessed_batch): + """Computes gradients against a batch of experiences. + + Either this or learn_on_batch() must be implemented by subclasses. + + Returns: + grads (list): List of gradient output values + info (dict): Extra policy-specific values + """ + raise NotImplementedError + + @DeveloperAPI + def apply_gradients(self, gradients): + """Applies previously computed gradients. + + Either this or learn_on_batch() must be implemented by subclasses. + """ + raise NotImplementedError + + @DeveloperAPI + def get_weights(self): + """Returns model weights. + + Returns: + weights (obj): Serializable copy or view of model weights + """ + raise NotImplementedError + + @DeveloperAPI + def set_weights(self, weights): + """Sets model weights. 
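Following on from the RandomDiscretePolicy sketch earlier, this shows how compute_single_action wraps one observation into a batch of size one, delegates to compute_actions, and unwraps the result; the observation and action spaces here are arbitrary choices for illustration.

import gym

obs_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(4, ))
act_space = gym.spaces.Discrete(2)
policy = RandomDiscretePolicy(obs_space, act_space, config={})

# No RNN, so the state list is empty.
action, state_out, extra = policy.compute_single_action(
    obs_space.sample(), state=[])
print(action)  # a single action, e.g. 0 or 1
assert state_out == [] and extra == {}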
+ + Arguments: + weights (obj): Serializable copy or view of model weights + """ + raise NotImplementedError + + @DeveloperAPI + def get_initial_state(self): + """Returns initial RNN state for the current policy.""" + return [] + + @DeveloperAPI + def get_state(self): + """Saves all local state. + + Returns: + state (obj): Serialized local state. + """ + return self.get_weights() + + @DeveloperAPI + def set_state(self, state): + """Restores all local state. + + Arguments: + state (obj): Serialized local state. + """ + self.set_weights(state) + + @DeveloperAPI + def on_global_var_update(self, global_vars): + """Called on an update to global vars. + + Arguments: + global_vars (dict): Global variables broadcast from the driver. + """ + pass + + @DeveloperAPI + def export_model(self, export_dir): + """Export Policy to local directory for serving. + + Arguments: + export_dir (str): Local writable directory. + """ + raise NotImplementedError + + @DeveloperAPI + def export_checkpoint(self, export_dir): + """Export Policy checkpoint to local directory. + + Argument: + export_dir (str): Local writable directory. + """ + raise NotImplementedError + + +def clip_action(action, space): + """Called to clip actions to the specified range of this policy. + + Arguments: + action: Single action. + space: Action space the actions should be present in. + + Returns: + Clipped batch of actions. + """ + + if isinstance(space, gym.spaces.Box): + return np.clip(action, space.low, space.high) + elif isinstance(space, gym.spaces.Tuple): + if type(action) not in (tuple, list): + raise ValueError("Expected tuple space for actions {}: {}".format( + action, space)) + out = [] + for a, s in zip(action, space.spaces): + out.append(clip_action(a, s)) + return out + else: + return action diff --git a/python/ray/rllib/policy/sample_batch.py b/python/ray/rllib/policy/sample_batch.py new file mode 100644 index 000000000000..a9515eeeac5a --- /dev/null +++ b/python/ray/rllib/policy/sample_batch.py @@ -0,0 +1,296 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import six +import collections +import numpy as np + +from ray.rllib.utils.annotations import PublicAPI, DeveloperAPI +from ray.rllib.utils.compression import pack, unpack, is_compressed +from ray.rllib.utils.memory import concat_aligned + +# Default policy id for single agent environments +DEFAULT_POLICY_ID = "default_policy" + + +@PublicAPI +class MultiAgentBatch(object): + """A batch of experiences from multiple policies in the environment. + + Attributes: + policy_batches (dict): Mapping from policy id to a normal SampleBatch + of experiences. Note that these batches may be of different length. + count (int): The number of timesteps in the environment this batch + contains. This will be less than the number of transitions this + batch contains across all policies in total. 
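The clip_action helper at the end of policy.py above clips Box actions elementwise to the space bounds, recurses into Tuple spaces, and passes anything else through unchanged; a small sketch with made-up values:

import gym
import numpy as np

from ray.rllib.policy.policy import clip_action

box = gym.spaces.Box(low=-1.0, high=1.0, shape=(2, ))
print(clip_action(np.array([2.5, -3.0]), box))  # -> [ 1. -1.]

nested = gym.spaces.Tuple([box, gym.spaces.Discrete(4)])
# The Box component is clipped, the Discrete component is returned as-is.
print(clip_action((np.array([2.5, -3.0]), 3), nested))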
+ """ + + @PublicAPI + def __init__(self, policy_batches, count): + self.policy_batches = policy_batches + self.count = count + + @staticmethod + @PublicAPI + def wrap_as_needed(batches, count): + if len(batches) == 1 and DEFAULT_POLICY_ID in batches: + return batches[DEFAULT_POLICY_ID] + return MultiAgentBatch(batches, count) + + @staticmethod + @PublicAPI + def concat_samples(samples): + policy_batches = collections.defaultdict(list) + total_count = 0 + for s in samples: + assert isinstance(s, MultiAgentBatch) + for policy_id, batch in s.policy_batches.items(): + policy_batches[policy_id].append(batch) + total_count += s.count + out = {} + for policy_id, batches in policy_batches.items(): + out[policy_id] = SampleBatch.concat_samples(batches) + return MultiAgentBatch(out, total_count) + + @PublicAPI + def copy(self): + return MultiAgentBatch( + {k: v.copy() + for (k, v) in self.policy_batches.items()}, self.count) + + @PublicAPI + def total(self): + ct = 0 + for batch in self.policy_batches.values(): + ct += batch.count + return ct + + @DeveloperAPI + def compress(self, bulk=False, columns=frozenset(["obs", "new_obs"])): + for batch in self.policy_batches.values(): + batch.compress(bulk=bulk, columns=columns) + + @DeveloperAPI + def decompress_if_needed(self, columns=frozenset(["obs", "new_obs"])): + for batch in self.policy_batches.values(): + batch.decompress_if_needed(columns) + + def __str__(self): + return "MultiAgentBatch({}, count={})".format( + str(self.policy_batches), self.count) + + def __repr__(self): + return "MultiAgentBatch({}, count={})".format( + str(self.policy_batches), self.count) + + +@PublicAPI +class SampleBatch(object): + """Wrapper around a dictionary with string keys and array-like values. + + For example, {"obs": [1, 2, 3], "reward": [0, -1, 1]} is a batch of three + samples, each with an "obs" and "reward" attribute. + """ + + # Outputs from interacting with the environment + CUR_OBS = "obs" + NEXT_OBS = "new_obs" + ACTIONS = "actions" + REWARDS = "rewards" + PREV_ACTIONS = "prev_actions" + PREV_REWARDS = "prev_rewards" + DONES = "dones" + INFOS = "infos" + + # Uniquely identifies an episode + EPS_ID = "eps_id" + + # Uniquely identifies a sample batch. This is important to distinguish RNN + # sequences from the same episode when multiple sample batches are + # concatenated (fusing sequences across batches can be unsafe). + UNROLL_ID = "unroll_id" + + # Uniquely identifies an agent within an episode + AGENT_INDEX = "agent_index" + + # Value function predictions emitted by the behaviour policy + VF_PREDS = "vf_preds" + + @PublicAPI + def __init__(self, *args, **kwargs): + """Constructs a sample batch (same params as dict constructor).""" + + self.data = dict(*args, **kwargs) + lengths = [] + for k, v in self.data.copy().items(): + assert isinstance(k, six.string_types), self + lengths.append(len(v)) + self.data[k] = np.array(v, copy=False) + if not lengths: + raise ValueError("Empty sample batch") + assert len(set(lengths)) == 1, "data columns must be same length" + self.count = lengths[0] + + @staticmethod + @PublicAPI + def concat_samples(samples): + if isinstance(samples[0], MultiAgentBatch): + return MultiAgentBatch.concat_samples(samples) + out = {} + samples = [s for s in samples if s.count > 0] + for k in samples[0].keys(): + out[k] = concat_aligned([s[k] for s in samples]) + return SampleBatch(out) + + @PublicAPI + def concat(self, other): + """Returns a new SampleBatch with each data column concatenated. 
+ + Examples: + >>> b1 = SampleBatch({"a": [1, 2]}) + >>> b2 = SampleBatch({"a": [3, 4, 5]}) + >>> print(b1.concat(b2)) + {"a": [1, 2, 3, 4, 5]} + """ + + assert self.keys() == other.keys(), "must have same columns" + out = {} + for k in self.keys(): + out[k] = concat_aligned([self[k], other[k]]) + return SampleBatch(out) + + @PublicAPI + def copy(self): + return SampleBatch( + {k: np.array(v, copy=True) + for (k, v) in self.data.items()}) + + @PublicAPI + def rows(self): + """Returns an iterator over data rows, i.e. dicts with column values. + + Examples: + >>> batch = SampleBatch({"a": [1, 2, 3], "b": [4, 5, 6]}) + >>> for row in batch.rows(): + print(row) + {"a": 1, "b": 4} + {"a": 2, "b": 5} + {"a": 3, "b": 6} + """ + + for i in range(self.count): + row = {} + for k in self.keys(): + row[k] = self[k][i] + yield row + + @PublicAPI + def columns(self, keys): + """Returns a list of just the specified columns. + + Examples: + >>> batch = SampleBatch({"a": [1], "b": [2], "c": [3]}) + >>> print(batch.columns(["a", "b"])) + [[1], [2]] + """ + + out = [] + for k in keys: + out.append(self[k]) + return out + + @PublicAPI + def shuffle(self): + """Shuffles the rows of this batch in-place.""" + + permutation = np.random.permutation(self.count) + for key, val in self.items(): + self[key] = val[permutation] + + @PublicAPI + def split_by_episode(self): + """Splits this batch's data by `eps_id`. + + Returns: + list of SampleBatch, one per distinct episode. + """ + + slices = [] + cur_eps_id = self.data["eps_id"][0] + offset = 0 + for i in range(self.count): + next_eps_id = self.data["eps_id"][i] + if next_eps_id != cur_eps_id: + slices.append(self.slice(offset, i)) + offset = i + cur_eps_id = next_eps_id + slices.append(self.slice(offset, self.count)) + for s in slices: + slen = len(set(s["eps_id"])) + assert slen == 1, (s, slen) + assert sum(s.count for s in slices) == self.count, (slices, self.count) + return slices + + @PublicAPI + def slice(self, start, end): + """Returns a slice of the row data of this batch. + + Arguments: + start (int): Starting index. + end (int): Ending index. + + Returns: + SampleBatch which has a slice of this batch's data. 
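A brief sketch of the SampleBatch helpers shown above, on a hand-written batch spanning two episodes; the column values are arbitrary.

import numpy as np

from ray.rllib.policy.sample_batch import SampleBatch

batch = SampleBatch({
    "eps_id": [0, 0, 1],
    "obs": np.array([[0.1], [0.2], [0.3]]),
    "rewards": [1.0, 0.5, -1.0],
})
assert batch.count == 3

episodes = batch.split_by_episode()      # one SampleBatch per eps_id
assert [e.count for e in episodes] == [2, 1]

for row in batch.slice(0, 2).rows():     # per-row dicts of column values
    print(row["obs"], row["rewards"])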
+ """ + + return SampleBatch({k: v[start:end] for k, v in self.data.items()}) + + @PublicAPI + def keys(self): + return self.data.keys() + + @PublicAPI + def items(self): + return self.data.items() + + @PublicAPI + def __getitem__(self, key): + return self.data[key] + + @PublicAPI + def __setitem__(self, key, item): + self.data[key] = item + + @DeveloperAPI + def compress(self, bulk=False, columns=frozenset(["obs", "new_obs"])): + for key in columns: + if key in self.data: + if bulk: + self.data[key] = pack(self.data[key]) + else: + self.data[key] = np.array( + [pack(o) for o in self.data[key]]) + + @DeveloperAPI + def decompress_if_needed(self, columns=frozenset(["obs", "new_obs"])): + for key in columns: + if key in self.data: + arr = self.data[key] + if is_compressed(arr): + self.data[key] = unpack(arr) + elif len(arr) > 0 and is_compressed(arr[0]): + self.data[key] = np.array( + [unpack(o) for o in self.data[key]]) + + def __str__(self): + return "SampleBatch({})".format(str(self.data)) + + def __repr__(self): + return "SampleBatch({})".format(str(self.data)) + + def __iter__(self): + return self.data.__iter__() + + def __contains__(self, x): + return x in self.data diff --git a/python/ray/rllib/policy/tf_policy.py b/python/ray/rllib/policy/tf_policy.py new file mode 100644 index 000000000000..079b1645244d --- /dev/null +++ b/python/ray/rllib/policy/tf_policy.py @@ -0,0 +1,520 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import errno +import logging +import numpy as np + +import ray +import ray.experimental.tf_utils +from ray.rllib.policy.policy import Policy, LEARNER_STATS_KEY +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.models.lstm import chop_into_sequences +from ray.rllib.utils.annotations import override, DeveloperAPI +from ray.rllib.utils.debug import log_once, summarize +from ray.rllib.utils.schedules import ConstantSchedule, PiecewiseSchedule, LinearSchedule +from ray.rllib.utils.tf_run_builder import TFRunBuilder +from ray.rllib.utils import try_import_tf + +tf = try_import_tf() +logger = logging.getLogger(__name__) + + +@DeveloperAPI +class TFPolicy(Policy): + """An agent policy and loss implemented in TensorFlow. + + Extending this class enables RLlib to perform TensorFlow specific + optimizations on the policy, e.g., parallelization across gpus or + fusing multiple graphs together in the multi-agent setting. + + Input tensors are typically shaped like [BATCH_SIZE, ...]. + + Attributes: + observation_space (gym.Space): observation space of the policy. + action_space (gym.Space): action space of the policy. + model (rllib.models.Model): RLlib model used for the policy. + + Examples: + >>> policy = TFPolicySubclass( + sess, obs_input, action_sampler, loss, loss_inputs) + + >>> print(policy.compute_actions([1, 0, 2])) + (array([0, 1, 1]), [], {}) + + >>> print(policy.postprocess_trajectory(SampleBatch({...}))) + SampleBatch({"action": ..., "advantages": ..., ...}) + """ + + @DeveloperAPI + def __init__(self, + observation_space, + action_space, + sess, + obs_input, + action_sampler, + loss, + loss_inputs, + model=None, + action_prob=None, + state_inputs=None, + state_outputs=None, + prev_action_input=None, + prev_reward_input=None, + seq_lens=None, + max_seq_len=20, + batch_divisibility_req=1, + update_ops=None): + """Initialize the policy. + + Arguments: + observation_space (gym.Space): Observation space of the env. + action_space (gym.Space): Action space of the env. 
+ sess (Session): TensorFlow session to use. + obs_input (Tensor): input placeholder for observations, of shape + [BATCH_SIZE, obs...]. + action_sampler (Tensor): Tensor for sampling an action, of shape + [BATCH_SIZE, action...] + loss (Tensor): scalar policy loss output tensor. + loss_inputs (list): a (name, placeholder) tuple for each loss + input argument. Each placeholder name must correspond to a + SampleBatch column key returned by postprocess_trajectory(), + and has shape [BATCH_SIZE, data...]. These keys will be read + from postprocessed sample batches and fed into the specified + placeholders during loss computation. + model (rllib.models.Model): used to integrate custom losses and + stats from user-defined RLlib models. + action_prob (Tensor): probability of the sampled action. + state_inputs (list): list of RNN state input Tensors. + state_outputs (list): list of RNN state output Tensors. + prev_action_input (Tensor): placeholder for previous actions + prev_reward_input (Tensor): placeholder for previous rewards + seq_lens (Tensor): placeholder for RNN sequence lengths, of shape + [NUM_SEQUENCES]. Note that NUM_SEQUENCES << BATCH_SIZE. See + models/lstm.py for more information. + max_seq_len (int): max sequence length for LSTM training. + batch_divisibility_req (int): pad all agent experiences batches to + multiples of this value. This only has an effect if not using + a LSTM model. + update_ops (list): override the batchnorm update ops to run when + applying gradients. Otherwise we run all update ops found in + the current variable scope. + """ + + self.observation_space = observation_space + self.action_space = action_space + self.model = model + self._sess = sess + self._obs_input = obs_input + self._prev_action_input = prev_action_input + self._prev_reward_input = prev_reward_input + self._sampler = action_sampler + self._is_training = self._get_is_training_placeholder() + self._action_prob = action_prob + self._state_inputs = state_inputs or [] + self._state_outputs = state_outputs or [] + self._seq_lens = seq_lens + self._max_seq_len = max_seq_len + self._batch_divisibility_req = batch_divisibility_req + self._update_ops = update_ops + self._stats_fetches = {} + + if loss is not None: + self._initialize_loss(loss, loss_inputs) + else: + self._loss = None + + if len(self._state_inputs) != len(self._state_outputs): + raise ValueError( + "Number of state input and output tensors must match, got: " + "{} vs {}".format(self._state_inputs, self._state_outputs)) + if len(self.get_initial_state()) != len(self._state_inputs): + raise ValueError( + "Length of initial state must match number of state inputs, " + "got: {} vs {}".format(self.get_initial_state(), + self._state_inputs)) + if self._state_inputs and self._seq_lens is None: + raise ValueError( + "seq_lens tensor must be given if state inputs are defined") + + def _initialize_loss(self, loss, loss_inputs): + self._loss_inputs = loss_inputs + self._loss_input_dict = dict(self._loss_inputs) + for i, ph in enumerate(self._state_inputs): + self._loss_input_dict["state_in_{}".format(i)] = ph + + if self.model: + self._loss = self.model.custom_loss(loss, self._loss_input_dict) + self._stats_fetches.update({"model": self.model.custom_stats()}) + else: + self._loss = loss + + self._optimizer = self.optimizer() + self._grads_and_vars = [ + (g, v) for (g, v) in self.gradients(self._optimizer, self._loss) + if g is not None + ] + self._grads = [g for (g, v) in self._grads_and_vars] + self._variables = 
ray.experimental.tf_utils.TensorFlowVariables( + self._loss, self._sess) + + # gather update ops for any batch norm layers + if not self._update_ops: + self._update_ops = tf.get_collection( + tf.GraphKeys.UPDATE_OPS, scope=tf.get_variable_scope().name) + if self._update_ops: + logger.debug("Update ops to run on apply gradient: {}".format( + self._update_ops)) + with tf.control_dependencies(self._update_ops): + self._apply_op = self.build_apply_op(self._optimizer, + self._grads_and_vars) + + if log_once("loss_used"): + logger.debug( + "These tensors were used in the loss_fn:\n\n{}\n".format( + summarize(self._loss_input_dict))) + + self._sess.run(tf.global_variables_initializer()) + + @override(Policy) + def compute_actions(self, + obs_batch, + state_batches=None, + prev_action_batch=None, + prev_reward_batch=None, + info_batch=None, + episodes=None, + **kwargs): + builder = TFRunBuilder(self._sess, "compute_actions") + fetches = self._build_compute_actions(builder, obs_batch, + state_batches, prev_action_batch, + prev_reward_batch) + return builder.get(fetches) + + @override(Policy) + def compute_gradients(self, postprocessed_batch): + assert self._loss is not None, "Loss not initialized" + builder = TFRunBuilder(self._sess, "compute_gradients") + fetches = self._build_compute_gradients(builder, postprocessed_batch) + return builder.get(fetches) + + @override(Policy) + def apply_gradients(self, gradients): + assert self._loss is not None, "Loss not initialized" + builder = TFRunBuilder(self._sess, "apply_gradients") + fetches = self._build_apply_gradients(builder, gradients) + builder.get(fetches) + + @override(Policy) + def learn_on_batch(self, postprocessed_batch): + assert self._loss is not None, "Loss not initialized" + builder = TFRunBuilder(self._sess, "learn_on_batch") + fetches = self._build_learn_on_batch(builder, postprocessed_batch) + return builder.get(fetches) + + @override(Policy) + def get_weights(self): + return self._variables.get_flat() + + @override(Policy) + def set_weights(self, weights): + return self._variables.set_flat(weights) + + @override(Policy) + def export_model(self, export_dir): + """Export tensorflow graph to export_dir for serving.""" + with self._sess.graph.as_default(): + builder = tf.saved_model.builder.SavedModelBuilder(export_dir) + signature_def_map = self._build_signature_def() + builder.add_meta_graph_and_variables( + self._sess, [tf.saved_model.tag_constants.SERVING], + signature_def_map=signature_def_map) + builder.save() + + @override(Policy) + def export_checkpoint(self, export_dir, filename_prefix="model"): + """Export tensorflow checkpoint to export_dir.""" + try: + os.makedirs(export_dir) + except OSError as e: + # ignore error if export dir already exists + if e.errno != errno.EEXIST: + raise + save_path = os.path.join(export_dir, filename_prefix) + with self._sess.graph.as_default(): + saver = tf.train.Saver() + saver.save(self._sess, save_path) + + @DeveloperAPI + def copy(self, existing_inputs): + """Creates a copy of self using existing input placeholders. + + Optional, only required to work with the multi-GPU optimizer.""" + raise NotImplementedError + + @DeveloperAPI + def extra_compute_action_feed_dict(self): + """Extra dict to pass to the compute actions session run.""" + return {} + + @DeveloperAPI + def extra_compute_action_fetches(self): + """Extra values to fetch and return from compute_actions(). + + By default we only return action probability info (if present). 
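For example, a TFPolicy subclass can extend these fetches so that an additional graph tensor is returned from compute_actions alongside the sampled actions; the self.value_out tensor below is a hypothetical attribute the subclass would define when building its graph, not part of this patch.

from ray.rllib.policy.tf_policy import TFPolicy
from ray.rllib.utils.annotations import override


class MyValueTFPolicy(TFPolicy):
    # ... __init__ builds the graph and sets self.value_out ...

    @override(TFPolicy)
    def extra_compute_action_fetches(self):
        # Keep the default action_prob fetch (if any) and add our tensor.
        fetches = TFPolicy.extra_compute_action_fetches(self)
        fetches["vf_preds"] = self.value_out
        return fetches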
+ """ + if self._action_prob is not None: + return {"action_prob": self._action_prob} + else: + return {} + + @DeveloperAPI + def extra_compute_grad_feed_dict(self): + """Extra dict to pass to the compute gradients session run.""" + return {} # e.g, kl_coeff + + @DeveloperAPI + def extra_compute_grad_fetches(self): + """Extra values to fetch and return from compute_gradients().""" + return {LEARNER_STATS_KEY: {}} # e.g, stats, td error, etc. + + @DeveloperAPI + def optimizer(self): + """TF optimizer to use for policy optimization.""" + if hasattr(self, "config"): + return tf.train.AdamOptimizer(self.config["lr"]) + else: + return tf.train.AdamOptimizer() + + @DeveloperAPI + def gradients(self, optimizer, loss): + """Override for custom gradient computation.""" + return optimizer.compute_gradients(loss) + + @DeveloperAPI + def build_apply_op(self, optimizer, grads_and_vars): + """Override for custom gradient apply computation.""" + + # specify global_step for TD3 which needs to count the num updates + return optimizer.apply_gradients( + self._grads_and_vars, + global_step=tf.train.get_or_create_global_step()) + + @DeveloperAPI + def _get_is_training_placeholder(self): + """Get the placeholder for _is_training, i.e., for batch norm layers. + + This can be called safely before __init__ has run. + """ + if not hasattr(self, "_is_training"): + self._is_training = tf.placeholder_with_default(False, ()) + return self._is_training + + def _extra_input_signature_def(self): + """Extra input signatures to add when exporting tf model. + Inferred from extra_compute_action_feed_dict() + """ + feed_dict = self.extra_compute_action_feed_dict() + return { + k.name: tf.saved_model.utils.build_tensor_info(k) + for k in feed_dict.keys() + } + + def _extra_output_signature_def(self): + """Extra output signatures to add when exporting tf model. + Inferred from extra_compute_action_fetches() + """ + fetches = self.extra_compute_action_fetches() + return { + k: tf.saved_model.utils.build_tensor_info(fetches[k]) + for k in fetches.keys() + } + + def _build_signature_def(self): + """Build signature def map for tensorflow SavedModelBuilder. + """ + # build input signatures + input_signature = self._extra_input_signature_def() + input_signature["observations"] = \ + tf.saved_model.utils.build_tensor_info(self._obs_input) + + if self._seq_lens is not None: + input_signature["seq_lens"] = \ + tf.saved_model.utils.build_tensor_info(self._seq_lens) + if self._prev_action_input is not None: + input_signature["prev_action"] = \ + tf.saved_model.utils.build_tensor_info(self._prev_action_input) + if self._prev_reward_input is not None: + input_signature["prev_reward"] = \ + tf.saved_model.utils.build_tensor_info(self._prev_reward_input) + input_signature["is_training"] = \ + tf.saved_model.utils.build_tensor_info(self._is_training) + + for state_input in self._state_inputs: + input_signature[state_input.name] = \ + tf.saved_model.utils.build_tensor_info(state_input) + + # build output signatures + output_signature = self._extra_output_signature_def() + output_signature["actions"] = \ + tf.saved_model.utils.build_tensor_info(self._sampler) + for state_output in self._state_outputs: + output_signature[state_output.name] = \ + tf.saved_model.utils.build_tensor_info(state_output) + signature_def = ( + tf.saved_model.signature_def_utils.build_signature_def( + input_signature, output_signature, + tf.saved_model.signature_constants.PREDICT_METHOD_NAME)) + signature_def_key = (tf.saved_model.signature_constants. 
+ DEFAULT_SERVING_SIGNATURE_DEF_KEY) + signature_def_map = {signature_def_key: signature_def} + return signature_def_map + + def _build_compute_actions(self, + builder, + obs_batch, + state_batches=None, + prev_action_batch=None, + prev_reward_batch=None, + episodes=None): + state_batches = state_batches or [] + if len(self._state_inputs) != len(state_batches): + raise ValueError( + "Must pass in RNN state batches for placeholders {}, got {}". + format(self._state_inputs, state_batches)) + builder.add_feed_dict(self.extra_compute_action_feed_dict()) + builder.add_feed_dict({self._obs_input: obs_batch}) + if state_batches: + builder.add_feed_dict({self._seq_lens: np.ones(len(obs_batch))}) + if self._prev_action_input is not None and prev_action_batch: + builder.add_feed_dict({self._prev_action_input: prev_action_batch}) + if self._prev_reward_input is not None and prev_reward_batch: + builder.add_feed_dict({self._prev_reward_input: prev_reward_batch}) + builder.add_feed_dict({self._is_training: False}) + builder.add_feed_dict(dict(zip(self._state_inputs, state_batches))) + fetches = builder.add_fetches([self._sampler] + self._state_outputs + + [self.extra_compute_action_fetches()]) + return fetches[0], fetches[1:-1], fetches[-1] + + def _build_compute_gradients(self, builder, postprocessed_batch): + builder.add_feed_dict(self.extra_compute_grad_feed_dict()) + builder.add_feed_dict({self._is_training: True}) + builder.add_feed_dict(self._get_loss_inputs_dict(postprocessed_batch)) + fetches = builder.add_fetches( + [self._grads, self._get_grad_and_stats_fetches()]) + return fetches[0], fetches[1] + + def _build_apply_gradients(self, builder, gradients): + if len(gradients) != len(self._grads): + raise ValueError( + "Unexpected number of gradients to apply, got {} for {}". 
+ format(gradients, self._grads)) + builder.add_feed_dict({self._is_training: True}) + builder.add_feed_dict(dict(zip(self._grads, gradients))) + fetches = builder.add_fetches([self._apply_op]) + return fetches[0] + + def _build_learn_on_batch(self, builder, postprocessed_batch): + builder.add_feed_dict(self.extra_compute_grad_feed_dict()) + builder.add_feed_dict(self._get_loss_inputs_dict(postprocessed_batch)) + builder.add_feed_dict({self._is_training: True}) + fetches = builder.add_fetches([ + self._apply_op, + self._get_grad_and_stats_fetches(), + ]) + return fetches[1] + + def _get_grad_and_stats_fetches(self): + fetches = self.extra_compute_grad_fetches() + if LEARNER_STATS_KEY not in fetches: + raise ValueError( + "Grad fetches should contain 'stats': {...} entry") + if self._stats_fetches: + fetches[LEARNER_STATS_KEY] = dict(self._stats_fetches, + **fetches[LEARNER_STATS_KEY]) + return fetches + + def _get_loss_inputs_dict(self, batch): + feed_dict = {} + if self._batch_divisibility_req > 1: + meets_divisibility_reqs = ( + len(batch[SampleBatch.CUR_OBS]) % + self._batch_divisibility_req == 0 + and max(batch[SampleBatch.AGENT_INDEX]) == 0) # not multiagent + else: + meets_divisibility_reqs = True + + # Simple case: not RNN nor do we need to pad + if not self._state_inputs and meets_divisibility_reqs: + for k, ph in self._loss_inputs: + feed_dict[ph] = batch[k] + return feed_dict + + if self._state_inputs: + max_seq_len = self._max_seq_len + dynamic_max = True + else: + max_seq_len = self._batch_divisibility_req + dynamic_max = False + + # RNN or multi-agent case + feature_keys = [k for k, v in self._loss_inputs] + state_keys = [ + "state_in_{}".format(i) for i in range(len(self._state_inputs)) + ] + feature_sequences, initial_states, seq_lens = chop_into_sequences( + batch[SampleBatch.EPS_ID], + batch[SampleBatch.UNROLL_ID], + batch[SampleBatch.AGENT_INDEX], [batch[k] for k in feature_keys], + [batch[k] for k in state_keys], + max_seq_len, + dynamic_max=dynamic_max) + for k, v in zip(feature_keys, feature_sequences): + feed_dict[self._loss_input_dict[k]] = v + for k, v in zip(state_keys, initial_states): + feed_dict[self._loss_input_dict[k]] = v + feed_dict[self._seq_lens] = seq_lens + + if log_once("rnn_feed_dict"): + logger.info("Padded input for RNN:\n\n{}\n".format( + summarize({ + "features": feature_sequences, + "initial_states": initial_states, + "seq_lens": seq_lens, + "max_seq_len": max_seq_len, + }))) + return feed_dict + + +@DeveloperAPI +class LearningRateSchedule(object): + """Mixin for TFPolicy that adds a learning rate schedule.""" + + @DeveloperAPI + def __init__(self, lr, lr_schedule): + self.cur_lr = tf.get_variable("lr", initializer=lr) + if lr_schedule is None: + self.lr_schedule = ConstantSchedule(lr) + elif isinstance(lr_schedule, list): + self.lr_schedule = PiecewiseSchedule( + lr_schedule, outside_value=lr_schedule[-1][-1]) + elif isinstance(lr_schedule, dict): + self.lr_schedule = LinearSchedule( + schedule_timesteps=lr_schedule["schedule_timesteps"], + initial_p=lr, + final_p=lr_schedule["final_lr"]) + else: + raise ValueError('lr_schedule must be either list, dict or None') + + @override(Policy) + def on_global_var_update(self, global_vars): + super(LearningRateSchedule, self).on_global_var_update(global_vars) + self.cur_lr.load( + self.lr_schedule.value(global_vars["timestep"]), + session=self._sess) + + @override(TFPolicy) + def optimizer(self): + return tf.train.AdamOptimizer(self.cur_lr) diff --git a/python/ray/rllib/policy/tf_policy_template.py 
b/python/ray/rllib/policy/tf_policy_template.py new file mode 100644 index 000000000000..36f482f18bf8 --- /dev/null +++ b/python/ray/rllib/policy/tf_policy_template.py @@ -0,0 +1,146 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from ray.rllib.policy.dynamic_tf_policy import DynamicTFPolicy +from ray.rllib.policy.policy import Policy +from ray.rllib.policy.tf_policy import TFPolicy +from ray.rllib.utils.annotations import override, DeveloperAPI + + +@DeveloperAPI +def build_tf_policy(name, + loss_fn, + get_default_config=None, + stats_fn=None, + grad_stats_fn=None, + extra_action_fetches_fn=None, + postprocess_fn=None, + optimizer_fn=None, + gradients_fn=None, + before_init=None, + before_loss_init=None, + after_init=None, + make_action_sampler=None, + mixins=None, + get_batch_divisibility_req=None): + """Helper function for creating a dynamic tf policy at runtime. + + Arguments: + name (str): name of the policy (e.g., "PPOTFPolicy") + loss_fn (func): function that returns a loss tensor the policy, + and dict of experience tensor placeholders + get_default_config (func): optional function that returns the default + config to merge with any overrides + stats_fn (func): optional function that returns a dict of + TF fetches given the policy and batch input tensors + grad_stats_fn (func): optional function that returns a dict of + TF fetches given the policy and loss gradient tensors + extra_action_fetches_fn (func): optional function that returns + a dict of TF fetches given the policy object + postprocess_fn (func): optional experience postprocessing function + that takes the same args as Policy.postprocess_trajectory() + optimizer_fn (func): optional function that returns a tf.Optimizer + given the policy and config + gradients_fn (func): optional function that returns a list of gradients + given a tf optimizer and loss tensor. If not specified, this + defaults to optimizer.compute_gradients(loss) + before_init (func): optional function to run at the beginning of + policy init that takes the same arguments as the policy constructor + before_loss_init (func): optional function to run prior to loss + init that takes the same arguments as the policy constructor + after_init (func): optional function to run at the end of policy init + that takes the same arguments as the policy constructor + make_action_sampler (func): optional function that returns a + tuple of action and action prob tensors. The function takes + (policy, input_dict, obs_space, action_space, config) as its + arguments + mixins (list): list of any class mixins for the returned policy class. 
+ These mixins will be applied in order and will have higher + precedence than the DynamicTFPolicy class + get_batch_divisibility_req (func): optional function that returns + the divisibility requirement for sample batches + + Returns: + a DynamicTFPolicy instance that uses the specified args + """ + + if not name.endswith("TFPolicy"): + raise ValueError("Name should match *TFPolicy", name) + + base = DynamicTFPolicy + while mixins: + + class new_base(mixins.pop(), base): + pass + + base = new_base + + class policy_cls(base): + def __init__(self, + obs_space, + action_space, + config, + existing_inputs=None): + if get_default_config: + config = dict(get_default_config(), **config) + + if before_init: + before_init(self, obs_space, action_space, config) + + def before_loss_init_wrapper(policy, obs_space, action_space, + config): + if before_loss_init: + before_loss_init(policy, obs_space, action_space, config) + if extra_action_fetches_fn is None: + self._extra_action_fetches = {} + else: + self._extra_action_fetches = extra_action_fetches_fn(self) + + DynamicTFPolicy.__init__( + self, + obs_space, + action_space, + config, + loss_fn, + stats_fn=stats_fn, + grad_stats_fn=grad_stats_fn, + before_loss_init=before_loss_init_wrapper, + existing_inputs=existing_inputs) + + if after_init: + after_init(self, obs_space, action_space, config) + + @override(Policy) + def postprocess_trajectory(self, + sample_batch, + other_agent_batches=None, + episode=None): + if not postprocess_fn: + return sample_batch + return postprocess_fn(self, sample_batch, other_agent_batches, + episode) + + @override(TFPolicy) + def optimizer(self): + if optimizer_fn: + return optimizer_fn(self, self.config) + else: + return TFPolicy.optimizer(self) + + @override(TFPolicy) + def gradients(self, optimizer, loss): + if gradients_fn: + return gradients_fn(self, optimizer, loss) + else: + return TFPolicy.gradients(self, optimizer, loss) + + @override(TFPolicy) + def extra_compute_action_fetches(self): + return dict( + TFPolicy.extra_compute_action_fetches(self), + **self._extra_action_fetches) + + policy_cls.__name__ = name + policy_cls.__qualname__ = name + return policy_cls diff --git a/python/ray/rllib/policy/torch_policy.py b/python/ray/rllib/policy/torch_policy.py new file mode 100644 index 000000000000..633e438c5ad7 --- /dev/null +++ b/python/ray/rllib/policy/torch_policy.py @@ -0,0 +1,173 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +import numpy as np +from threading import Lock + +try: + import torch +except ImportError: + pass # soft dep + +from ray.rllib.policy.policy import Policy, LEARNER_STATS_KEY +from ray.rllib.utils.annotations import override +from ray.rllib.utils.tracking_dict import UsageTrackingDict + + +class TorchPolicy(Policy): + """Template for a PyTorch policy and loss to use with RLlib. + + This is similar to TFPolicy, but for PyTorch. + + Attributes: + observation_space (gym.Space): observation space of the policy. + action_space (gym.Space): action space of the policy. + lock (Lock): Lock that must be held around PyTorch ops on this graph. + This is necessary when using the async sampler. + """ + + def __init__(self, observation_space, action_space, model, loss, + action_distribution_cls): + """Build a policy from policy and loss torch modules. + + Note that model will be placed on GPU device if CUDA_VISIBLE_DEVICES + is set. Only single GPU is supported for now. 
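To show the template in use, here is a minimal policy class produced by build_tf_policy; the MyTFPolicy name and the REINFORCE-style loss (using raw rewards in place of proper advantages) are illustrative sketches, and the loss assumes the policy's action distribution exposes a logp() method as RLlib's built-in distributions do.

from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.tf_policy_template import build_tf_policy
from ray.rllib.utils import try_import_tf

tf = try_import_tf()


def policy_gradient_loss(policy, batch_tensors):
    # -E[log pi(a|s) * r] over the sampled batch.
    actions = batch_tensors[SampleBatch.ACTIONS]
    rewards = batch_tensors[SampleBatch.REWARDS]
    return -tf.reduce_mean(policy.action_dist.logp(actions) * rewards)


MyTFPolicy = build_tf_policy(
    name="MyTFPolicy",
    loss_fn=policy_gradient_loss)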
+ + Arguments: + observation_space (gym.Space): observation space of the policy. + action_space (gym.Space): action space of the policy. + model (nn.Module): PyTorch policy module. Given observations as + input, this module must return a list of outputs where the + first item is action logits, and the rest can be any value. + loss (func): Function that takes (policy, batch_tensors) + and returns a single scalar loss. + action_distribution_cls (ActionDistribution): Class for action + distribution. + """ + self.observation_space = observation_space + self.action_space = action_space + self.lock = Lock() + self.device = (torch.device("cuda") + if bool(os.environ.get("CUDA_VISIBLE_DEVICES", None)) + else torch.device("cpu")) + self._model = model.to(self.device) + self._loss = loss + self._optimizer = self.optimizer() + self._action_dist_cls = action_distribution_cls + + @override(Policy) + def compute_actions(self, + obs_batch, + state_batches=None, + prev_action_batch=None, + prev_reward_batch=None, + info_batch=None, + episodes=None, + **kwargs): + with self.lock: + with torch.no_grad(): + ob = torch.from_numpy(np.array(obs_batch)) \ + .float().to(self.device) + model_out = self._model({"obs": ob}, state_batches) + logits, _, vf, state = model_out + action_dist = self._action_dist_cls(logits) + actions = action_dist.sample() + return (actions.cpu().numpy(), + [h.cpu().numpy() for h in state], + self.extra_action_out(model_out)) + + @override(Policy) + def learn_on_batch(self, postprocessed_batch): + batch_tensors = self._lazy_tensor_dict(postprocessed_batch) + + with self.lock: + loss_out = self._loss(self, batch_tensors) + self._optimizer.zero_grad() + loss_out.backward() + + grad_process_info = self.extra_grad_process() + self._optimizer.step() + + grad_info = self.extra_grad_info(batch_tensors) + grad_info.update(grad_process_info) + return {LEARNER_STATS_KEY: grad_info} + + @override(Policy) + def compute_gradients(self, postprocessed_batch): + batch_tensors = self._lazy_tensor_dict(postprocessed_batch) + + with self.lock: + loss_out = self._loss(self, batch_tensors) + self._optimizer.zero_grad() + loss_out.backward() + + grad_process_info = self.extra_grad_process() + + # Note that return values are just references; + # calling zero_grad will modify the values + grads = [] + for p in self._model.parameters(): + if p.grad is not None: + grads.append(p.grad.data.cpu().numpy()) + else: + grads.append(None) + + grad_info = self.extra_grad_info(batch_tensors) + grad_info.update(grad_process_info) + return grads, {LEARNER_STATS_KEY: grad_info} + + @override(Policy) + def apply_gradients(self, gradients): + with self.lock: + for g, p in zip(gradients, self._model.parameters()): + if g is not None: + p.grad = torch.from_numpy(g).to(self.device) + self._optimizer.step() + + @override(Policy) + def get_weights(self): + with self.lock: + return {k: v.cpu() for k, v in self._model.state_dict().items()} + + @override(Policy) + def set_weights(self, weights): + with self.lock: + self._model.load_state_dict(weights) + + @override(Policy) + def get_initial_state(self): + return [s.numpy() for s in self._model.state_init()] + + def extra_grad_process(self): + """Allow subclass to do extra processing on gradients and + return processing info.""" + return {} + + def extra_action_out(self, model_out): + """Returns dict of extra info to include in experience batch. 
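For instance, a TorchPolicy subclass can report its model's value-function output alongside the sampled actions by overriding extra_action_out; this assumes the model returns (logits, features, value, state), matching the unpacking in compute_actions above, and the subclass name is illustrative.

from ray.rllib.policy.torch_policy import TorchPolicy
from ray.rllib.utils.annotations import override


class ValueReportingTorchPolicy(TorchPolicy):
    @override(TorchPolicy)
    def extra_action_out(self, model_out):
        # model_out is the full tuple returned by the torch model.
        _, _, vf, _ = model_out
        return {"vf_preds": vf.cpu().numpy()}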
+ + Arguments: + model_out (list): Outputs of the policy model module.""" + return {} + + def extra_grad_info(self, batch_tensors): + """Return dict of extra grad info.""" + + return {} + + def optimizer(self): + """Custom PyTorch optimizer to use.""" + if hasattr(self, "config"): + return torch.optim.Adam( + self._model.parameters(), lr=self.config["lr"]) + else: + return torch.optim.Adam(self._model.parameters()) + + def _lazy_tensor_dict(self, postprocessed_batch): + batch_tensors = UsageTrackingDict(postprocessed_batch) + batch_tensors.set_get_interceptor( + lambda arr: torch.from_numpy(arr).to(self.device)) + return batch_tensors diff --git a/python/ray/rllib/policy/torch_policy_template.py b/python/ray/rllib/policy/torch_policy_template.py new file mode 100644 index 000000000000..049591c04671 --- /dev/null +++ b/python/ray/rllib/policy/torch_policy_template.py @@ -0,0 +1,133 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from ray.rllib.policy.policy import Policy +from ray.rllib.policy.torch_policy import TorchPolicy +from ray.rllib.models.catalog import ModelCatalog +from ray.rllib.utils.annotations import override, DeveloperAPI + + +@DeveloperAPI +def build_torch_policy(name, + loss_fn, + get_default_config=None, + stats_fn=None, + postprocess_fn=None, + extra_action_out_fn=None, + extra_grad_process_fn=None, + optimizer_fn=None, + before_init=None, + after_init=None, + make_model_and_action_dist=None, + mixins=None): + """Helper function for creating a torch policy at runtime. + + Arguments: + name (str): name of the policy (e.g., "PPOTFPolicy") + loss_fn (func): function that returns a loss tensor the policy, + and dict of experience tensor placeholders + get_default_config (func): optional function that returns the default + config to merge with any overrides + stats_fn (func): optional function that returns a dict of + values given the policy and batch input tensors + postprocess_fn (func): optional experience postprocessing function + that takes the same args as Policy.postprocess_trajectory() + extra_action_out_fn (func): optional function that returns + a dict of extra values to include in experiences + extra_grad_process_fn (func): optional function that is called after + gradients are computed and returns processing info + optimizer_fn (func): optional function that returns a torch optimizer + given the policy and config + before_init (func): optional function to run at the beginning of + policy init that takes the same arguments as the policy constructor + after_init (func): optional function to run at the end of policy init + that takes the same arguments as the policy constructor + make_model_and_action_dist (func): optional func that takes the same + arguments as policy init and returns a tuple of model instance and + torch action distribution class. If not specified, the default + model and action dist from the catalog will be used + mixins (list): list of any class mixins for the returned policy class. 
+ These mixins will be applied in order and will have higher + precedence than the TorchPolicy class + + Returns: + a TorchPolicy instance that uses the specified args + """ + + if not name.endswith("TorchPolicy"): + raise ValueError("Name should match *TorchPolicy", name) + + base = TorchPolicy + while mixins: + + class new_base(mixins.pop(), base): + pass + + base = new_base + + class graph_cls(base): + def __init__(self, obs_space, action_space, config): + if get_default_config: + config = dict(get_default_config(), **config) + self.config = config + + if before_init: + before_init(self, obs_space, action_space, config) + + if make_model_and_action_dist: + self.model, self.dist_class = make_model_and_action_dist( + self, obs_space, action_space, config) + else: + self.dist_class, logit_dim = ModelCatalog.get_action_dist( + action_space, self.config["model"], torch=True) + self.model = ModelCatalog.get_torch_model( + obs_space, logit_dim, self.config["model"]) + + TorchPolicy.__init__(self, obs_space, action_space, self.model, + loss_fn, self.dist_class) + + if after_init: + after_init(self, obs_space, action_space, config) + + @override(Policy) + def postprocess_trajectory(self, + sample_batch, + other_agent_batches=None, + episode=None): + if not postprocess_fn: + return sample_batch + return postprocess_fn(self, sample_batch, other_agent_batches, + episode) + + @override(TorchPolicy) + def extra_grad_process(self): + if extra_grad_process_fn: + return extra_grad_process_fn(self) + else: + return TorchPolicy.extra_grad_process(self) + + @override(TorchPolicy) + def extra_action_out(self, model_out): + if extra_action_out_fn: + return extra_action_out_fn(self, model_out) + else: + return TorchPolicy.extra_action_out(self, model_out) + + @override(TorchPolicy) + def optimizer(self): + if optimizer_fn: + return optimizer_fn(self, self.config) + else: + return TorchPolicy.optimizer(self) + + @override(TorchPolicy) + def extra_grad_info(self, batch_tensors): + if stats_fn: + return stats_fn(self, batch_tensors) + else: + return TorchPolicy.extra_grad_info(self, batch_tensors) + + graph_cls.__name__ = name + graph_cls.__qualname__ = name + return graph_cls diff --git a/python/ray/rllib/rollout.py b/python/ray/rllib/rollout.py index 2bb25f5c40af..efa5743c0a54 100755 --- a/python/ray/rllib/rollout.py +++ b/python/ray/rllib/rollout.py @@ -15,7 +15,7 @@ from ray.rllib.agents.registry import get_agent_class from ray.rllib.env import MultiAgentEnv from ray.rllib.env.base_env import _DUMMY_AGENT_ID -from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID +from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID from ray.tune.util import merge_dicts EXAMPLE_USAGE = """ diff --git a/python/ray/rllib/tests/test_catalog.py b/python/ray/rllib/tests/test_catalog.py index fe89152c6cbd..1c93b40ed484 100644 --- a/python/ray/rllib/tests/test_catalog.py +++ b/python/ray/rllib/tests/test_catalog.py @@ -1,6 +1,5 @@ import gym import numpy as np -import tensorflow as tf import unittest from gym.spaces import Box, Discrete, Tuple @@ -12,6 +11,9 @@ Preprocessor) from ray.rllib.models.fcnet import FullyConnectedNetwork from ray.rllib.models.visionnet import VisionNetwork +from ray.rllib.utils import try_import_tf + +tf = try_import_tf() class CustomPreprocessor(Preprocessor): diff --git a/python/ray/rllib/tests/test_dependency.py b/python/ray/rllib/tests/test_dependency.py new file mode 100644 index 000000000000..2df0b4b95937 --- /dev/null +++ b/python/ray/rllib/tests/test_dependency.py @@ -0,0 +1,24 @@ 
+#!/usr/bin/env python + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import sys + +os.environ["RLLIB_TEST_NO_TF_IMPORT"] = "1" + +if __name__ == "__main__": + from ray.rllib.agents.a3c import A2CTrainer + assert "tensorflow" not in sys.modules, "TF initially present" + + # note: no ray.init(), to test it works without Ray + trainer = A2CTrainer( + env="CartPole-v0", config={ + "use_pytorch": True, + "num_workers": 0 + }) + trainer.train() + + assert "tensorflow" not in sys.modules, "TF should not be imported" diff --git a/python/ray/rllib/tests/test_evaluators.py b/python/ray/rllib/tests/test_evaluators.py index 36ded2b4e800..7f2ef740e4f5 100644 --- a/python/ray/rllib/tests/test_evaluators.py +++ b/python/ray/rllib/tests/test_evaluators.py @@ -7,7 +7,7 @@ import ray from ray.rllib.agents.dqn import DQNTrainer from ray.rllib.agents.a3c import A3CTrainer -from ray.rllib.agents.dqn.dqn_policy_graph import _adjust_nstep +from ray.rllib.agents.dqn.dqn_policy import _adjust_nstep from ray.tune.registry import register_env import gym diff --git a/python/ray/rllib/tests/test_external_env.py b/python/ray/rllib/tests/test_external_env.py index 3379639612f6..3b2158959267 100644 --- a/python/ray/rllib/tests/test_external_env.py +++ b/python/ray/rllib/tests/test_external_env.py @@ -13,8 +13,8 @@ from ray.rllib.agents.pg import PGTrainer from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator from ray.rllib.env.external_env import ExternalEnv -from ray.rllib.tests.test_policy_evaluator import (BadPolicyGraph, - MockPolicyGraph, MockEnv) +from ray.rllib.tests.test_policy_evaluator import (BadPolicy, MockPolicy, + MockEnv) from ray.tune.registry import register_env @@ -121,7 +121,7 @@ class TestExternalEnv(unittest.TestCase): def testExternalEnvCompleteEpisodes(self): ev = PolicyEvaluator( env_creator=lambda _: SimpleServing(MockEnv(25)), - policy_graph=MockPolicyGraph, + policy=MockPolicy, batch_steps=40, batch_mode="complete_episodes") for _ in range(3): @@ -131,7 +131,7 @@ def testExternalEnvCompleteEpisodes(self): def testExternalEnvTruncateEpisodes(self): ev = PolicyEvaluator( env_creator=lambda _: SimpleServing(MockEnv(25)), - policy_graph=MockPolicyGraph, + policy=MockPolicy, batch_steps=40, batch_mode="truncate_episodes") for _ in range(3): @@ -141,7 +141,7 @@ def testExternalEnvTruncateEpisodes(self): def testExternalEnvOffPolicy(self): ev = PolicyEvaluator( env_creator=lambda _: SimpleOffPolicyServing(MockEnv(25), 42), - policy_graph=MockPolicyGraph, + policy=MockPolicy, batch_steps=40, batch_mode="complete_episodes") for _ in range(3): @@ -153,7 +153,7 @@ def testExternalEnvOffPolicy(self): def testExternalEnvBadActions(self): ev = PolicyEvaluator( env_creator=lambda _: SimpleServing(MockEnv(25)), - policy_graph=BadPolicyGraph, + policy=BadPolicy, sample_async=True, batch_steps=40, batch_mode="truncate_episodes") @@ -198,7 +198,7 @@ def testTrainCartpoleMulti(self): def testExternalEnvHorizonNotSupported(self): ev = PolicyEvaluator( env_creator=lambda _: SimpleServing(MockEnv(25)), - policy_graph=MockPolicyGraph, + policy=MockPolicy, episode_horizon=20, batch_steps=10, batch_mode="complete_episodes") diff --git a/python/ray/rllib/tests/test_external_multi_agent_env.py b/python/ray/rllib/tests/test_external_multi_agent_env.py index e5e182b38655..fcb3de634cbe 100644 --- a/python/ray/rllib/tests/test_external_multi_agent_env.py +++ b/python/ray/rllib/tests/test_external_multi_agent_env.py @@ -8,11 +8,11 @@ 
import unittest import ray -from ray.rllib.agents.pg.pg_policy_graph import PGPolicyGraph +from ray.rllib.agents.pg.pg_policy import PGTFPolicy from ray.rllib.optimizers import SyncSamplesOptimizer from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator from ray.rllib.env.external_multi_agent_env import ExternalMultiAgentEnv -from ray.rllib.tests.test_policy_evaluator import MockPolicyGraph +from ray.rllib.tests.test_policy_evaluator import MockPolicy from ray.rllib.tests.test_external_env import make_simple_serving from ray.rllib.tests.test_multi_agent_env import BasicMultiAgent, MultiCartpole from ray.rllib.evaluation.metrics import collect_metrics @@ -25,7 +25,7 @@ def testExternalMultiAgentEnvCompleteEpisodes(self): agents = 4 ev = PolicyEvaluator( env_creator=lambda _: SimpleMultiServing(BasicMultiAgent(agents)), - policy_graph=MockPolicyGraph, + policy=MockPolicy, batch_steps=40, batch_mode="complete_episodes") for _ in range(3): @@ -37,7 +37,7 @@ def testExternalMultiAgentEnvTruncateEpisodes(self): agents = 4 ev = PolicyEvaluator( env_creator=lambda _: SimpleMultiServing(BasicMultiAgent(agents)), - policy_graph=MockPolicyGraph, + policy=MockPolicy, batch_steps=40, batch_mode="truncate_episodes") for _ in range(3): @@ -51,9 +51,9 @@ def testExternalMultiAgentEnvSample(self): obs_space = gym.spaces.Discrete(2) ev = PolicyEvaluator( env_creator=lambda _: SimpleMultiServing(BasicMultiAgent(agents)), - policy_graph={ - "p0": (MockPolicyGraph, obs_space, act_space, {}), - "p1": (MockPolicyGraph, obs_space, act_space, {}), + policy={ + "p0": (MockPolicy, obs_space, act_space, {}), + "p1": (MockPolicy, obs_space, act_space, {}), }, policy_mapping_fn=lambda agent_id: "p{}".format(agent_id % 2), batch_steps=50) @@ -67,12 +67,12 @@ def testTrainExternalMultiCartpoleManyPolicies(self): obs_space = single_env.observation_space policies = {} for i in range(20): - policies["pg_{}".format(i)] = (PGPolicyGraph, obs_space, act_space, + policies["pg_{}".format(i)] = (PGTFPolicy, obs_space, act_space, {}) policy_ids = list(policies.keys()) ev = PolicyEvaluator( env_creator=lambda _: MultiCartpole(n), - policy_graph=policies, + policy=policies, policy_mapping_fn=lambda agent_id: random.choice(policy_ids), batch_steps=100) optimizer = SyncSamplesOptimizer(ev, []) diff --git a/python/ray/rllib/tests/test_io.py b/python/ray/rllib/tests/test_io.py index 9f92c9107c4e..c98e4553dcf1 100644 --- a/python/ray/rllib/tests/test_io.py +++ b/python/ray/rllib/tests/test_io.py @@ -15,7 +15,7 @@ import ray from ray.rllib.agents.pg import PGTrainer -from ray.rllib.agents.pg.pg_policy_graph import PGPolicyGraph +from ray.rllib.agents.pg.pg_policy import PGTFPolicy from ray.rllib.evaluation import SampleBatch from ray.rllib.offline import IOContext, JsonWriter, JsonReader from ray.rllib.offline.json_writer import _to_json @@ -159,7 +159,7 @@ def testMultiAgent(self): def gen_policy(): obs_space = single_env.observation_space act_space = single_env.action_space - return (PGPolicyGraph, obs_space, act_space, {}) + return (PGTFPolicy, obs_space, act_space, {}) pg = PGTrainer( env="multi_cartpole", @@ -167,7 +167,7 @@ def gen_policy(): "num_workers": 0, "output": self.test_dir, "multiagent": { - "policy_graphs": { + "policies": { "policy_1": gen_policy(), "policy_2": gen_policy(), }, @@ -188,7 +188,7 @@ def gen_policy(): "input_evaluation": ["simulation"], "train_batch_size": 2000, "multiagent": { - "policy_graphs": { + "policies": { "policy_1": gen_policy(), "policy_2": gen_policy(), }, diff --git 
a/python/ray/rllib/tests/test_lstm.py b/python/ray/rllib/tests/test_lstm.py index 385f2d7bc1ba..dd9c7ccd9d86 100644 --- a/python/ray/rllib/tests/test_lstm.py +++ b/python/ray/rllib/tests/test_lstm.py @@ -6,8 +6,6 @@ import numpy as np import pickle import unittest -import tensorflow as tf -import tensorflow.contrib.rnn as rnn import ray from ray.rllib.agents.ppo import PPOTrainer @@ -16,6 +14,9 @@ from ray.rllib.models.misc import linear, normc_initializer from ray.rllib.models.model import Model from ray.tune.registry import register_env +from ray.rllib.utils import try_import_tf + +tf = try_import_tf() class LSTMUtilsTest(unittest.TestCase): @@ -104,7 +105,7 @@ def spy(sequences, state_in, state_out, seq_lens): last_layer = add_time_dimension(features, self.seq_lens) # Setup the LSTM cell - lstm = rnn.BasicLSTMCell(cell_size, state_is_tuple=True) + lstm = tf.nn.rnn_cell.BasicLSTMCell(cell_size, state_is_tuple=True) self.state_init = [ np.zeros(lstm.state_size.c, np.float32), np.zeros(lstm.state_size.h, np.float32) @@ -121,7 +122,7 @@ def spy(sequences, state_in, state_out, seq_lens): self.state_in = [c_in, h_in] # Setup LSTM outputs - state_in = rnn.LSTMStateTuple(c_in, h_in) + state_in = tf.nn.rnn_cell.LSTMStateTuple(c_in, h_in) lstm_out, lstm_state = tf.nn.dynamic_rnn( lstm, last_layer, diff --git a/python/ray/rllib/tests/test_multi_agent_env.py b/python/ray/rllib/tests/test_multi_agent_env.py index eccb9aa82fb8..be4bfcd3428f 100644 --- a/python/ray/rllib/tests/test_multi_agent_env.py +++ b/python/ray/rllib/tests/test_multi_agent_env.py @@ -8,14 +8,14 @@ import ray from ray.rllib.agents.pg import PGTrainer -from ray.rllib.agents.pg.pg_policy_graph import PGPolicyGraph -from ray.rllib.agents.dqn.dqn_policy_graph import DQNPolicyGraph +from ray.rllib.agents.pg.pg_policy import PGTFPolicy +from ray.rllib.agents.dqn.dqn_policy import DQNTFPolicy from ray.rllib.optimizers import (SyncSamplesOptimizer, SyncReplayOptimizer, AsyncGradientsOptimizer) from ray.rllib.tests.test_policy_evaluator import (MockEnv, MockEnv2, - MockPolicyGraph) + MockPolicy) from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator -from ray.rllib.evaluation.policy_graph import PolicyGraph +from ray.rllib.policy.policy import Policy from ray.rllib.evaluation.metrics import collect_metrics from ray.rllib.env.base_env import _MultiAgentEnvToBaseEnv from ray.rllib.env.multi_agent_env import MultiAgentEnv @@ -329,9 +329,9 @@ def testMultiAgentSample(self): obs_space = gym.spaces.Discrete(2) ev = PolicyEvaluator( env_creator=lambda _: BasicMultiAgent(5), - policy_graph={ - "p0": (MockPolicyGraph, obs_space, act_space, {}), - "p1": (MockPolicyGraph, obs_space, act_space, {}), + policy={ + "p0": (MockPolicy, obs_space, act_space, {}), + "p1": (MockPolicy, obs_space, act_space, {}), }, policy_mapping_fn=lambda agent_id: "p{}".format(agent_id % 2), batch_steps=50) @@ -347,9 +347,9 @@ def testMultiAgentSampleSyncRemote(self): obs_space = gym.spaces.Discrete(2) ev = PolicyEvaluator( env_creator=lambda _: BasicMultiAgent(5), - policy_graph={ - "p0": (MockPolicyGraph, obs_space, act_space, {}), - "p1": (MockPolicyGraph, obs_space, act_space, {}), + policy={ + "p0": (MockPolicy, obs_space, act_space, {}), + "p1": (MockPolicy, obs_space, act_space, {}), }, policy_mapping_fn=lambda agent_id: "p{}".format(agent_id % 2), batch_steps=50, @@ -364,9 +364,9 @@ def testMultiAgentSampleAsyncRemote(self): obs_space = gym.spaces.Discrete(2) ev = PolicyEvaluator( env_creator=lambda _: BasicMultiAgent(5), - policy_graph={ - "p0": 
(MockPolicyGraph, obs_space, act_space, {}), - "p1": (MockPolicyGraph, obs_space, act_space, {}), + policy={ + "p0": (MockPolicy, obs_space, act_space, {}), + "p1": (MockPolicy, obs_space, act_space, {}), }, policy_mapping_fn=lambda agent_id: "p{}".format(agent_id % 2), batch_steps=50, @@ -380,9 +380,9 @@ def testMultiAgentSampleWithHorizon(self): obs_space = gym.spaces.Discrete(2) ev = PolicyEvaluator( env_creator=lambda _: BasicMultiAgent(5), - policy_graph={ - "p0": (MockPolicyGraph, obs_space, act_space, {}), - "p1": (MockPolicyGraph, obs_space, act_space, {}), + policy={ + "p0": (MockPolicy, obs_space, act_space, {}), + "p1": (MockPolicy, obs_space, act_space, {}), }, policy_mapping_fn=lambda agent_id: "p{}".format(agent_id % 2), episode_horizon=10, # test with episode horizon set @@ -395,9 +395,9 @@ def testSampleFromEarlyDoneEnv(self): obs_space = gym.spaces.Discrete(2) ev = PolicyEvaluator( env_creator=lambda _: EarlyDoneMultiAgent(), - policy_graph={ - "p0": (MockPolicyGraph, obs_space, act_space, {}), - "p1": (MockPolicyGraph, obs_space, act_space, {}), + policy={ + "p0": (MockPolicy, obs_space, act_space, {}), + "p1": (MockPolicy, obs_space, act_space, {}), }, policy_mapping_fn=lambda agent_id: "p{}".format(agent_id % 2), batch_mode="complete_episodes", @@ -411,8 +411,8 @@ def testMultiAgentSampleRoundRobin(self): obs_space = gym.spaces.Discrete(10) ev = PolicyEvaluator( env_creator=lambda _: RoundRobinMultiAgent(5, increment_obs=True), - policy_graph={ - "p0": (MockPolicyGraph, obs_space, act_space, {}), + policy={ + "p0": (MockPolicy, obs_space, act_space, {}), }, policy_mapping_fn=lambda agent_id: "p0", batch_steps=50) @@ -445,7 +445,7 @@ def testMultiAgentSampleRoundRobin(self): def testCustomRNNStateValues(self): h = {"some": {"arbitrary": "structure", "here": [1, 2, 3]}} - class StatefulPolicyGraph(PolicyGraph): + class StatefulPolicy(Policy): def compute_actions(self, obs_batch, state_batches, @@ -460,7 +460,7 @@ def get_initial_state(self): ev = PolicyEvaluator( env_creator=lambda _: gym.make("CartPole-v0"), - policy_graph=StatefulPolicyGraph, + policy=StatefulPolicy, batch_steps=5) batch = ev.sample() self.assertEqual(batch.count, 5) @@ -470,7 +470,7 @@ def get_initial_state(self): self.assertEqual(batch["state_out_0"][1], h) def testReturningModelBasedRolloutsData(self): - class ModelBasedPolicyGraph(PGPolicyGraph): + class ModelBasedPolicy(PGTFPolicy): def compute_actions(self, obs_batch, state_batches, @@ -505,9 +505,9 @@ def compute_actions(self, act_space = single_env.action_space ev = PolicyEvaluator( env_creator=lambda _: MultiCartpole(2), - policy_graph={ - "p0": (ModelBasedPolicyGraph, obs_space, act_space, {}), - "p1": (ModelBasedPolicyGraph, obs_space, act_space, {}), + policy={ + "p0": (ModelBasedPolicy, obs_space, act_space, {}), + "p1": (ModelBasedPolicy, obs_space, act_space, {}), }, policy_mapping_fn=lambda agent_id: "p0", batch_steps=5) @@ -547,7 +547,7 @@ def gen_policy(): config={ "num_workers": 0, "multiagent": { - "policy_graphs": { + "policies": { "policy_1": gen_policy(), "policy_2": gen_policy(), }, @@ -579,17 +579,17 @@ def _testWithOptimizer(self, optimizer_cls): # happen since the replay buffer doesn't encode extra fields like # "advantages" that PG uses. 
policies = { - "p1": (DQNPolicyGraph, obs_space, act_space, dqn_config), - "p2": (DQNPolicyGraph, obs_space, act_space, dqn_config), + "p1": (DQNTFPolicy, obs_space, act_space, dqn_config), + "p2": (DQNTFPolicy, obs_space, act_space, dqn_config), } else: policies = { - "p1": (PGPolicyGraph, obs_space, act_space, {}), - "p2": (DQNPolicyGraph, obs_space, act_space, dqn_config), + "p1": (PGTFPolicy, obs_space, act_space, {}), + "p2": (DQNTFPolicy, obs_space, act_space, dqn_config), } ev = PolicyEvaluator( env_creator=lambda _: MultiCartpole(n), - policy_graph=policies, + policy=policies, policy_mapping_fn=lambda agent_id: ["p1", "p2"][agent_id % 2], batch_steps=50) if optimizer_cls == AsyncGradientsOptimizer: @@ -600,7 +600,7 @@ def policy_mapper(agent_id): remote_evs = [ PolicyEvaluator.as_remote().remote( env_creator=lambda _: MultiCartpole(n), - policy_graph=policies, + policy=policies, policy_mapping_fn=policy_mapper, batch_steps=50) ] @@ -610,12 +610,16 @@ def policy_mapper(agent_id): for i in range(200): ev.foreach_policy(lambda p, _: p.set_epsilon( max(0.02, 1 - i * .02)) - if isinstance(p, DQNPolicyGraph) else None) + if isinstance(p, DQNTFPolicy) else None) optimizer.step() result = collect_metrics(ev, remote_evs) if i % 20 == 0: - ev.foreach_policy(lambda p, _: p.update_target() if isinstance( - p, DQNPolicyGraph) else None) + + def do_update(p): + if isinstance(p, DQNTFPolicy): + p.update_target() + + ev.foreach_policy(lambda p, _: do_update(p)) print("Iter {}, rew {}".format(i, result["policy_reward_mean"])) print("Total reward", result["episode_reward_mean"]) @@ -640,12 +644,12 @@ def testTrainMultiCartpoleManyPolicies(self): obs_space = env.observation_space policies = {} for i in range(20): - policies["pg_{}".format(i)] = (PGPolicyGraph, obs_space, act_space, + policies["pg_{}".format(i)] = (PGTFPolicy, obs_space, act_space, {}) policy_ids = list(policies.keys()) ev = PolicyEvaluator( env_creator=lambda _: MultiCartpole(n), - policy_graph=policies, + policy=policies, policy_mapping_fn=lambda agent_id: random.choice(policy_ids), batch_steps=100) optimizer = SyncSamplesOptimizer(ev, []) diff --git a/python/ray/rllib/tests/test_nested_spaces.py b/python/ray/rllib/tests/test_nested_spaces.py index dc45ca3f605e..0220ba01722c 100644 --- a/python/ray/rllib/tests/test_nested_spaces.py +++ b/python/ray/rllib/tests/test_nested_spaces.py @@ -7,14 +7,12 @@ from gym import spaces from gym.envs.registration import EnvSpec import gym -import tensorflow.contrib.slim as slim -import tensorflow as tf import unittest import ray from ray.rllib.agents.a3c import A2CTrainer from ray.rllib.agents.pg import PGTrainer -from ray.rllib.agents.pg.pg_policy_graph import PGPolicyGraph +from ray.rllib.agents.pg.pg_policy import PGTFPolicy from ray.rllib.env import MultiAgentEnv from ray.rllib.env.base_env import BaseEnv from ray.rllib.env.vector_env import VectorEnv @@ -25,6 +23,9 @@ from ray.rllib.rollout import rollout from ray.rllib.tests.test_external_env import SimpleServing from ray.tune.registry import register_env +from ray.rllib.utils import try_import_tf + +tf = try_import_tf() DICT_SPACE = spaces.Dict({ "sensors": spaces.Dict({ @@ -179,8 +180,8 @@ def spy(pos, front_cam, task): stateful=True) with tf.control_dependencies([spy_fn]): - output = slim.fully_connected( - input_dict["obs"]["sensors"]["position"], num_outputs) + output = tf.layers.dense(input_dict["obs"]["sensors"]["position"], + num_outputs) return output, output @@ -208,7 +209,7 @@ def spy(pos, cam, task): stateful=True) with 
tf.control_dependencies([spy_fn]): - output = slim.fully_connected(input_dict["obs"][0], num_outputs) + output = tf.layers.dense(input_dict["obs"][0], num_outputs) return output, output @@ -330,12 +331,12 @@ def testMultiAgentComplexSpaces(self): "sample_batch_size": 5, "train_batch_size": 5, "multiagent": { - "policy_graphs": { + "policies": { "tuple_policy": ( - PGPolicyGraph, TUPLE_SPACE, act_space, + PGTFPolicy, TUPLE_SPACE, act_space, {"model": {"custom_model": "tuple_spy"}}), "dict_policy": ( - PGPolicyGraph, DICT_SPACE, act_space, + PGTFPolicy, DICT_SPACE, act_space, {"model": {"custom_model": "dict_spy"}}), }, "policy_mapping_fn": lambda a: { diff --git a/python/ray/rllib/tests/test_optimizers.py b/python/ray/rllib/tests/test_optimizers.py index 65992a220ba2..f851cfc33f12 100644 --- a/python/ray/rllib/tests/test_optimizers.py +++ b/python/ray/rllib/tests/test_optimizers.py @@ -4,18 +4,20 @@ import gym import numpy as np -import tensorflow as tf import time import unittest import ray from ray.rllib.agents.ppo import PPOTrainer -from ray.rllib.agents.ppo.ppo_policy_graph import PPOPolicyGraph +from ray.rllib.agents.ppo.ppo_policy import PPOTFPolicy from ray.rllib.evaluation import SampleBatch from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator from ray.rllib.optimizers import AsyncGradientsOptimizer, AsyncSamplesOptimizer from ray.rllib.optimizers.aso_tree_aggregator import TreeAggregator from ray.rllib.tests.mock_evaluator import _MockEvaluator +from ray.rllib.utils import try_import_tf + +tf = try_import_tf() class AsyncOptimizerTest(unittest.TestCase): @@ -238,12 +240,12 @@ def make_sess(): local = PolicyEvaluator( env_creator=lambda _: gym.make("CartPole-v0"), - policy_graph=PPOPolicyGraph, + policy=PPOTFPolicy, tf_session_creator=make_sess) remotes = [ PolicyEvaluator.as_remote().remote( env_creator=lambda _: gym.make("CartPole-v0"), - policy_graph=PPOPolicyGraph, + policy=PPOTFPolicy, tf_session_creator=make_sess) ] return local, remotes diff --git a/python/ray/rllib/tests/test_perf.py b/python/ray/rllib/tests/test_perf.py index f437c9628dfd..e31530f44ced 100644 --- a/python/ray/rllib/tests/test_perf.py +++ b/python/ray/rllib/tests/test_perf.py @@ -8,7 +8,7 @@ import ray from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator -from ray.rllib.tests.test_policy_evaluator import MockPolicyGraph +from ray.rllib.tests.test_policy_evaluator import MockPolicy class TestPerf(unittest.TestCase): @@ -19,7 +19,7 @@ def testBaselinePerformance(self): for _ in range(20): ev = PolicyEvaluator( env_creator=lambda _: gym.make("CartPole-v0"), - policy_graph=MockPolicyGraph, + policy=MockPolicy, batch_steps=100) start = time.time() count = 0 diff --git a/python/ray/rllib/tests/test_policy_evaluator.py b/python/ray/rllib/tests/test_policy_evaluator.py index 6283a5b66314..dc0dcaff6782 100644 --- a/python/ray/rllib/tests/test_policy_evaluator.py +++ b/python/ray/rllib/tests/test_policy_evaluator.py @@ -14,14 +14,14 @@ from ray.rllib.agents.a3c import A2CTrainer from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator from ray.rllib.evaluation.metrics import collect_metrics -from ray.rllib.evaluation.policy_graph import PolicyGraph +from ray.rllib.policy.policy import Policy from ray.rllib.evaluation.postprocessing import compute_advantages -from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID, SampleBatch +from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch from ray.rllib.env.vector_env import VectorEnv from ray.tune.registry import 
register_env -class MockPolicyGraph(PolicyGraph): +class MockPolicy(Policy): def compute_actions(self, obs_batch, state_batches, @@ -39,7 +39,7 @@ def postprocess_trajectory(self, return compute_advantages(batch, 100.0, 0.9, use_gae=False) -class BadPolicyGraph(PolicyGraph): +class BadPolicy(Policy): def compute_actions(self, obs_batch, state_batches, @@ -132,8 +132,7 @@ def get_unwrapped(self): class TestPolicyEvaluator(unittest.TestCase): def testBasic(self): ev = PolicyEvaluator( - env_creator=lambda _: gym.make("CartPole-v0"), - policy_graph=MockPolicyGraph) + env_creator=lambda _: gym.make("CartPole-v0"), policy=MockPolicy) batch = ev.sample() for key in [ "obs", "actions", "rewards", "dones", "advantages", @@ -157,8 +156,7 @@ def to_prev(vec): def testBatchIds(self): ev = PolicyEvaluator( - env_creator=lambda _: gym.make("CartPole-v0"), - policy_graph=MockPolicyGraph) + env_creator=lambda _: gym.make("CartPole-v0"), policy=MockPolicy) batch1 = ev.sample() batch2 = ev.sample() self.assertEqual(len(set(batch1["unroll_id"])), 1) @@ -229,7 +227,7 @@ def testRewardClipping(self): # clipping on ev = PolicyEvaluator( env_creator=lambda _: MockEnv2(episode_length=10), - policy_graph=MockPolicyGraph, + policy=MockPolicy, clip_rewards=True, batch_mode="complete_episodes") self.assertEqual(max(ev.sample()["rewards"]), 1) @@ -239,7 +237,7 @@ def testRewardClipping(self): # clipping off ev2 = PolicyEvaluator( env_creator=lambda _: MockEnv2(episode_length=10), - policy_graph=MockPolicyGraph, + policy=MockPolicy, clip_rewards=False, batch_mode="complete_episodes") self.assertEqual(max(ev2.sample()["rewards"]), 100) @@ -249,7 +247,7 @@ def testRewardClipping(self): def testHardHorizon(self): ev = PolicyEvaluator( env_creator=lambda _: MockEnv(episode_length=10), - policy_graph=MockPolicyGraph, + policy=MockPolicy, batch_mode="complete_episodes", batch_steps=10, episode_horizon=4, @@ -263,7 +261,7 @@ def testHardHorizon(self): def testSoftHorizon(self): ev = PolicyEvaluator( env_creator=lambda _: MockEnv(episode_length=10), - policy_graph=MockPolicyGraph, + policy=MockPolicy, batch_mode="complete_episodes", batch_steps=10, episode_horizon=4, @@ -277,11 +275,11 @@ def testSoftHorizon(self): def testMetrics(self): ev = PolicyEvaluator( env_creator=lambda _: MockEnv(episode_length=10), - policy_graph=MockPolicyGraph, + policy=MockPolicy, batch_mode="complete_episodes") remote_ev = PolicyEvaluator.as_remote().remote( env_creator=lambda _: MockEnv(episode_length=10), - policy_graph=MockPolicyGraph, + policy=MockPolicy, batch_mode="complete_episodes") ev.sample() ray.get(remote_ev.sample.remote()) @@ -293,7 +291,7 @@ def testAsync(self): ev = PolicyEvaluator( env_creator=lambda _: gym.make("CartPole-v0"), sample_async=True, - policy_graph=MockPolicyGraph) + policy=MockPolicy) batch = ev.sample() for key in ["obs", "actions", "rewards", "dones", "advantages"]: self.assertIn(key, batch) @@ -302,7 +300,7 @@ def testAsync(self): def testAutoVectorization(self): ev = PolicyEvaluator( env_creator=lambda cfg: MockEnv(episode_length=20, config=cfg), - policy_graph=MockPolicyGraph, + policy=MockPolicy, batch_mode="truncate_episodes", batch_steps=2, num_envs=8) @@ -325,7 +323,7 @@ def testAutoVectorization(self): def testBatchesLargerWhenVectorized(self): ev = PolicyEvaluator( env_creator=lambda _: MockEnv(episode_length=8), - policy_graph=MockPolicyGraph, + policy=MockPolicy, batch_mode="truncate_episodes", batch_steps=4, num_envs=4) @@ -340,7 +338,7 @@ def testBatchesLargerWhenVectorized(self): def 
testVectorEnvSupport(self): ev = PolicyEvaluator( env_creator=lambda _: MockVectorEnv(episode_length=20, num_envs=8), - policy_graph=MockPolicyGraph, + policy=MockPolicy, batch_mode="truncate_episodes", batch_steps=10) for _ in range(8): @@ -357,7 +355,7 @@ def testVectorEnvSupport(self): def testTruncateEpisodes(self): ev = PolicyEvaluator( env_creator=lambda _: MockEnv(10), - policy_graph=MockPolicyGraph, + policy=MockPolicy, batch_steps=15, batch_mode="truncate_episodes") batch = ev.sample() @@ -366,7 +364,7 @@ def testTruncateEpisodes(self): def testCompleteEpisodes(self): ev = PolicyEvaluator( env_creator=lambda _: MockEnv(10), - policy_graph=MockPolicyGraph, + policy=MockPolicy, batch_steps=5, batch_mode="complete_episodes") batch = ev.sample() @@ -375,7 +373,7 @@ def testCompleteEpisodes(self): def testCompleteEpisodesPacking(self): ev = PolicyEvaluator( env_creator=lambda _: MockEnv(10), - policy_graph=MockPolicyGraph, + policy=MockPolicy, batch_steps=15, batch_mode="complete_episodes") batch = ev.sample() @@ -387,7 +385,7 @@ def testCompleteEpisodesPacking(self): def testFilterSync(self): ev = PolicyEvaluator( env_creator=lambda _: gym.make("CartPole-v0"), - policy_graph=MockPolicyGraph, + policy=MockPolicy, sample_async=True, observation_filter="ConcurrentMeanStdFilter") time.sleep(2) @@ -400,7 +398,7 @@ def testFilterSync(self): def testGetFilters(self): ev = PolicyEvaluator( env_creator=lambda _: gym.make("CartPole-v0"), - policy_graph=MockPolicyGraph, + policy=MockPolicy, sample_async=True, observation_filter="ConcurrentMeanStdFilter") self.sample_and_flush(ev) @@ -415,7 +413,7 @@ def testGetFilters(self): def testSyncFilter(self): ev = PolicyEvaluator( env_creator=lambda _: gym.make("CartPole-v0"), - policy_graph=MockPolicyGraph, + policy=MockPolicy, sample_async=True, observation_filter="ConcurrentMeanStdFilter") obs_f = self.sample_and_flush(ev) diff --git a/python/ray/rllib/tuned_examples/regression_tests/pendulum-appo-vtrace.yaml b/python/ray/rllib/tuned_examples/regression_tests/pendulum-appo-vtrace.yaml new file mode 100644 index 000000000000..245e908cc89c --- /dev/null +++ b/python/ray/rllib/tuned_examples/regression_tests/pendulum-appo-vtrace.yaml @@ -0,0 +1,12 @@ +pendulum-appo-vt: + env: Pendulum-v0 + run: APPO + stop: + episode_reward_mean: -900 # just check it learns a bit + timesteps_total: 500000 + config: + num_gpus: 0 + num_workers: 1 + gamma: 0.95 + train_batch_size: 50 + vtrace: true diff --git a/python/ray/rllib/utils/__init__.py b/python/ray/rllib/utils/__init__.py index 7aab0f2a0dfb..aad5590fd097 100644 --- a/python/ray/rllib/utils/__init__.py +++ b/python/ray/rllib/utils/__init__.py @@ -1,4 +1,5 @@ import logging +import os from ray.rllib.utils.filter_manager import FilterManager from ray.rllib.utils.filter import Filter @@ -9,13 +10,30 @@ logger = logging.getLogger(__name__) -def renamed_class(cls): +def renamed_class(cls, old_name): + """Helper class for renaming classes with a warning.""" + + class DeprecationWrapper(cls): + # note: **kw not supported for ray.remote classes + def __init__(self, *args, **kw): + new_name = cls.__module__ + "." + cls.__name__ + logger.warn("DeprecationWarning: {} has been renamed to {}. ". 
+ format(old_name, new_name) + + "This will raise an error in the future.") + cls.__init__(self, *args, **kw) + + DeprecationWrapper.__name__ = cls.__name__ + + return DeprecationWrapper + + +def renamed_agent(cls): """Helper class for renaming Agent => Trainer with a warning.""" class DeprecationWrapper(cls): def __init__(self, config=None, env=None, logger_creator=None): old_name = cls.__name__.replace("Trainer", "Agent") - new_name = cls.__name__ + new_name = cls.__module__ + "." + cls.__name__ logger.warn("DeprecationWarning: {} has been renamed to {}. ". format(old_name, new_name) + "This will raise an error in the future.") @@ -26,6 +44,23 @@ def __init__(self, config=None, env=None, logger_creator=None): return DeprecationWrapper +def try_import_tf(): + if "RLLIB_TEST_NO_TF_IMPORT" in os.environ: + logger.warning("Not importing TensorFlow for test purposes") + return None + + try: + import tensorflow.compat.v1 as tf + tf.disable_v2_behavior() + return tf + except ImportError: + try: + import tensorflow as tf + return tf + except ImportError: + return None + + __all__ = [ "Filter", "FilterManager", @@ -34,4 +69,5 @@ def __init__(self, config=None, env=None, logger_creator=None): "merge_dicts", "deep_update", "renamed_class", + "try_import_tf", ] diff --git a/python/ray/rllib/utils/debug.py b/python/ray/rllib/utils/debug.py index ce86326f27a0..0f636b0f00ef 100644 --- a/python/ray/rllib/utils/debug.py +++ b/python/ray/rllib/utils/debug.py @@ -6,7 +6,7 @@ import pprint import time -from ray.rllib.evaluation.sample_batch import SampleBatch, MultiAgentBatch +from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch _logged = set() _disabled = False diff --git a/python/ray/rllib/utils/explained_variance.py b/python/ray/rllib/utils/explained_variance.py index 942f0f8f31f0..a3e9cbadbee3 100644 --- a/python/ray/rllib/utils/explained_variance.py +++ b/python/ray/rllib/utils/explained_variance.py @@ -2,7 +2,9 @@ from __future__ import division from __future__ import print_function -import tensorflow as tf +from ray.rllib.utils import try_import_tf + +tf = try_import_tf() def explained_variance(y, pred): diff --git a/python/ray/rllib/utils/seed.py b/python/ray/rllib/utils/seed.py index bec02b6ad6ec..3675fd11913d 100644 --- a/python/ray/rllib/utils/seed.py +++ b/python/ray/rllib/utils/seed.py @@ -4,7 +4,9 @@ import numpy as np import random -import tensorflow as tf +from ray.rllib.utils import try_import_tf + +tf = try_import_tf() def seed(np_seed=0, random_seed=0, tf_seed=0): diff --git a/python/ray/rllib/utils/tf_run_builder.py b/python/ray/rllib/utils/tf_run_builder.py index ef411b047d7b..ed4525ddfa79 100644 --- a/python/ray/rllib/utils/tf_run_builder.py +++ b/python/ray/rllib/utils/tf_run_builder.py @@ -6,11 +6,10 @@ import os import time -import tensorflow as tf -from tensorflow.python.client import timeline - from ray.rllib.utils.debug import log_once +from ray.rllib.utils import try_import_tf +tf = try_import_tf() logger = logging.getLogger(__name__) @@ -48,6 +47,8 @@ def get(self, to_fetch): self.session, self.fetches, self.debug_name, self.feed_dict, os.environ.get("TF_TIMELINE_DIR")) except Exception: + logger.exception("Error fetching: {}, feed_dict={}".format( + self.fetches, self.feed_dict)) raise ValueError("Error fetching: {}, feed_dict={}".format( self.fetches, self.feed_dict)) if isinstance(to_fetch, int): @@ -65,6 +66,8 @@ def get(self, to_fetch): def run_timeline(sess, ops, debug_name, feed_dict={}, timeline_dir=None): if timeline_dir: + from 
tensorflow.python.client import timeline + run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE) run_metadata = tf.RunMetadata() start = time.time() diff --git a/python/ray/rllib/utils/tracking_dict.py b/python/ray/rllib/utils/tracking_dict.py new file mode 100644 index 000000000000..c0f145734e78 --- /dev/null +++ b/python/ray/rllib/utils/tracking_dict.py @@ -0,0 +1,32 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +class UsageTrackingDict(dict): + """Dict that tracks which keys have been accessed. + + It can also intercept gets and allow an arbitrary callback to be applied + (i.e., to lazily convert numpy arrays to Tensors). + + We make the simplifying assumption only __getitem__ is used to access + values. + """ + + def __init__(self, *args, **kwargs): + dict.__init__(self, *args, **kwargs) + self.accessed_keys = set() + self.intercepted_values = {} + self.get_interceptor = None + + def set_get_interceptor(self, fn): + self.get_interceptor = fn + + def __getitem__(self, key): + self.accessed_keys.add(key) + value = dict.__getitem__(self, key) + if self.get_interceptor: + if key not in self.intercepted_values: + self.intercepted_values[key] = self.get_interceptor(value) + value = self.intercepted_values[key] + return value diff --git a/python/ray/scripts/scripts.py b/python/ray/scripts/scripts.py index d870c655a4a9..5a0529a51c5d 100644 --- a/python/ray/scripts/scripts.py +++ b/python/ray/scripts/scripts.py @@ -380,9 +380,11 @@ def start(node_ip_address, redis_address, redis_port, num_redis_shards, @cli.command() def stop(): + # Note that raylet needs to exit before object store, otherwise + # it cannot exit gracefully. processes_to_kill = [ - "plasma_store_server", "raylet", + "plasma_store_server", "raylet_monitor", "monitor.py", "redis-server", @@ -527,8 +529,8 @@ def attach(cluster_config_file, start, tmux, cluster_name, new): @cli.command() @click.argument("cluster_config_file", required=True, type=str) -@click.argument("source", required=True, type=str) -@click.argument("target", required=True, type=str) +@click.argument("source", required=False, type=str) +@click.argument("target", required=False, type=str) @click.option( "--cluster-name", "-n", @@ -541,8 +543,8 @@ def rsync_down(cluster_config_file, source, target, cluster_name): @cli.command() @click.argument("cluster_config_file", required=True, type=str) -@click.argument("source", required=True, type=str) -@click.argument("target", required=True, type=str) +@click.argument("source", required=False, type=str) +@click.argument("target", required=False, type=str) @click.option( "--cluster-name", "-n", @@ -553,7 +555,7 @@ def rsync_up(cluster_config_file, source, target, cluster_name): rsync(cluster_config_file, source, target, cluster_name, down=False) -@cli.command() +@cli.command(context_settings={"ignore_unknown_options": True}) @click.argument("cluster_config_file", required=True, type=str) @click.option( "--docker", @@ -586,14 +588,17 @@ def rsync_up(cluster_config_file, source, target, cluster_name): @click.option( "--port-forward", required=False, type=int, help="Port to forward.") @click.argument("script", required=True, type=str) -@click.argument("script_args", required=False, type=str, nargs=-1) +@click.option("--args", required=False, type=str, help="Script args.") def submit(cluster_config_file, docker, screen, tmux, stop, start, - cluster_name, port_forward, script, script_args): + cluster_name, port_forward, script, args): """Uploads 
and runs a script on the specified cluster. The script is automatically synced to the following location: os.path.join("~", os.path.basename(script)) + + Example: + >>> ray submit [CLUSTER.YAML] experiment.py --args="--smoke-test" """ assert not (screen and tmux), "Can specify only one of `screen` or `tmux`." @@ -604,7 +609,10 @@ def submit(cluster_config_file, docker, screen, tmux, stop, start, target = os.path.join("~", os.path.basename(script)) rsync(cluster_config_file, script, target, cluster_name, down=False) - cmd = " ".join(["python", target] + list(script_args)) + command_parts = ["python", target] + if args is not None: + command_parts += [args] + cmd = " ".join(command_parts) exec_cluster(cluster_config_file, cmd, docker, screen, tmux, stop, False, cluster_name, port_forward) @@ -738,7 +746,7 @@ def timeline(redis_address): ray.init(redis_address=redis_address) time = datetime.today().strftime("%Y-%m-%d_%H-%M-%S") filename = "/tmp/ray-timeline-{}.json".format(time) - ray.global_state.chrome_tracing_dump(filename=filename) + ray.timeline(filename=filename) size = os.path.getsize(filename) logger.info("Trace file written to {} ({} bytes).".format(filename, size)) logger.info( diff --git a/python/ray/services.py b/python/ray/services.py index 37ed61254848..24c96020f569 100644 --- a/python/ray/services.py +++ b/python/ray/services.py @@ -101,7 +101,7 @@ def get_address_info_from_redis_helper(redis_address, # Redis) must have run "CONFIG SET protected-mode no". redis_client = create_redis_client(redis_address, password=redis_password) - client_table = ray.experimental.state.parse_client_table(redis_client) + client_table = ray.state._parse_client_table(redis_client) if len(client_table) == 0: raise Exception( "Redis has started but no raylets have registered yet.") @@ -1063,6 +1063,7 @@ def start_raylet(redis_address, plasma_store_name, worker_path, temp_dir, + session_dir, num_cpus=None, num_gpus=None, resources=None, @@ -1088,6 +1089,7 @@ def start_raylet(redis_address, worker_path (str): The path of the Python file that new worker processes will execute. temp_dir (str): The path of the temporary directory Ray will use. + session_dir (str): The path of this session. num_cpus: The CPUs allocated for this raylet. num_gpus: The GPUs allocated for this raylet. resources: The custom resources allocated for this raylet. @@ -1145,7 +1147,7 @@ def start_raylet(redis_address, plasma_store_name, raylet_name, redis_password, - os.path.join(temp_dir, "sockets"), + session_dir, ) else: java_worker_command = "" @@ -1224,7 +1226,7 @@ def build_java_worker_command( plasma_store_name, raylet_name, redis_password, - temp_dir, + session_dir, ): """This method assembles the command used to start a Java worker. @@ -1235,7 +1237,7 @@ def build_java_worker_command( to. raylet_name (str): The name of the raylet socket to create. redis_password (str): The password of connect to redis. - temp_dir (str): The path of the temporary directory Ray will use. + session_dir (str): The path of this session. Returns: The command string for starting Java worker. """ @@ -1256,8 +1258,7 @@ def build_java_worker_command( command += "-Dray.redis.password={} ".format(redis_password) command += "-Dray.home={} ".format(RAY_HOME) - # TODO(suquark): We should use temp_dir as the input of a java worker. 
- command += "-Dray.log-dir={} ".format(os.path.join(temp_dir, "sockets")) + command += "-Dray.log-dir={} ".format(os.path.join(session_dir, "logs")) if java_worker_options: # Put `java_worker_options` in the last, so it can overwrite the @@ -1574,7 +1575,7 @@ def start_raylet_monitor(redis_address, "--config_list={}".format(config_str), ] if redis_password: - command += [redis_password] + command += ["--redis_password={}".format(redis_password)] process_info = start_ray_process( command, ray_constants.PROCESS_TYPE_RAYLET_MONITOR, diff --git a/python/ray/experimental/state.py b/python/ray/state.py similarity index 79% rename from python/ray/experimental/state.py rename to python/ray/state.py index 31d4b77c64e6..6b2c8a4ef8bc 100644 --- a/python/ray/experimental/state.py +++ b/python/ray/state.py @@ -4,6 +4,7 @@ from collections import defaultdict import json +import logging import sys import time @@ -13,11 +14,14 @@ from ray.ray_constants import ID_SIZE from ray import services +from ray.core.generated.EntryType import EntryType from ray.utils import (decode, binary_to_object_id, binary_to_hex, hex_to_binary) +logger = logging.getLogger(__name__) -def parse_client_table(redis_client): + +def _parse_client_table(redis_client): """Read the client table. Args: @@ -54,29 +58,43 @@ def parse_client_table(redis_client): } client_id = ray.utils.binary_to_hex(client.ClientId()) - # If this client is being removed, then it must + if client.EntryType() == EntryType.INSERTION: + ordered_client_ids.append(client_id) + node_info[client_id] = { + "ClientID": client_id, + "EntryType": client.EntryType(), + "NodeManagerAddress": decode( + client.NodeManagerAddress(), allow_none=True), + "NodeManagerPort": client.NodeManagerPort(), + "ObjectManagerPort": client.ObjectManagerPort(), + "ObjectStoreSocketName": decode( + client.ObjectStoreSocketName(), allow_none=True), + "RayletSocketName": decode( + client.RayletSocketName(), allow_none=True), + "Resources": resources + } + + # If this client is being updated, then it must # have previously been inserted, and # it cannot have previously been removed. - if not client.IsInsertion(): - assert client_id in node_info, "Client removed not found!" - assert node_info[client_id]["IsInsertion"], ( - "Unexpected duplicate removal of client.") else: - ordered_client_ids.append(client_id) - - node_info[client_id] = { - "ClientID": client_id, - "IsInsertion": client.IsInsertion(), - "NodeManagerAddress": decode( - client.NodeManagerAddress(), allow_none=True), - "NodeManagerPort": client.NodeManagerPort(), - "ObjectManagerPort": client.ObjectManagerPort(), - "ObjectStoreSocketName": decode( - client.ObjectStoreSocketName(), allow_none=True), - "RayletSocketName": decode( - client.RayletSocketName(), allow_none=True), - "Resources": resources - } + assert client_id in node_info, "Client not found!" 
+ assert node_info[client_id]["EntryType"] != EntryType.DELETION, ( + "Unexpected updation of deleted client.") + res_map = node_info[client_id]["Resources"] + if client.EntryType() == EntryType.RES_CREATEUPDATE: + for res in resources: + res_map[res] = resources[res] + elif client.EntryType() == EntryType.RES_DELETE: + for res in resources: + res_map.pop(res, None) + elif client.EntryType() == EntryType.DELETION: + pass # Do nothing with the resmap if client deletion + else: + raise RuntimeError("Unexpected EntryType {}".format( + client.EntryType())) + node_info[client_id]["Resources"] = res_map + node_info[client_id]["EntryType"] = client.EntryType() # NOTE: We return the list comprehension below instead of simply doing # 'list(node_info.values())' in order to have the nodes appear in the order # that they joined the cluster. Python dictionaries do not preserve @@ -113,11 +131,11 @@ def _check_connected(self): yet. """ if self.redis_client is None: - raise Exception("The ray.global_state API cannot be used before " + raise Exception("The ray global state API cannot be used before " "ray.init has been called.") if self.redis_clients is None: - raise Exception("The ray.global_state API cannot be used before " + raise Exception("The ray global state API cannot be used before " "ray.init has been called.") def disconnect(self): @@ -393,7 +411,7 @@ def client_table(self): """ self._check_connected() - return parse_client_table(self.redis_client) + return _parse_client_table(self.redis_client) def _profile_table(self, batch_id): """Get the profile events for a given batch of profile events. @@ -446,6 +464,7 @@ def _profile_table(self, batch_id): return profile_events def profile_table(self): + self._check_connected() profile_table_keys = self._keys( ray.gcs_utils.TablePrefix_PROFILE_string + "*") batch_identifiers_binary = [ @@ -546,6 +565,8 @@ def chrome_tracing_dump(self, filename=None): # TODO(rkn): This should support viewing just a window of time or a # limited number of events. + self._check_connected() + profile_table = self.profile_table() all_events = [] @@ -611,8 +632,10 @@ def chrome_tracing_object_transfer_dump(self, filename=None): If filename is not provided, this returns a list of profiling events. Each profile event is a dictionary. 
""" + self._check_connected() + client_id_to_address = {} - for client_info in ray.global_state.client_table(): + for client_info in self.client_table(): client_id_to_address[client_info["ClientID"]] = "{}:{}".format( client_info["NodeManagerAddress"], client_info["ObjectManagerPort"]) @@ -688,6 +711,8 @@ def chrome_tracing_object_transfer_dump(self, filename=None): def workers(self): """Get a dictionary mapping worker ID to worker information.""" + self._check_connected() + worker_keys = self.redis_client.keys("Worker*") workers_data = {} @@ -708,22 +733,6 @@ def workers(self): worker_info[b"stdout_file"]) return workers_data - def actors(self): - actor_keys = self.redis_client.keys("Actor:*") - actor_info = {} - for key in actor_keys: - info = self.redis_client.hgetall(key) - actor_id = key[len("Actor:"):] - assert len(actor_id) == ID_SIZE - actor_info[binary_to_hex(actor_id)] = { - "class_id": binary_to_hex(info[b"class_id"]), - "driver_id": binary_to_hex(info[b"driver_id"]), - "raylet_id": binary_to_hex(info[b"raylet_id"]), - "num_gpus": int(info[b"num_gpus"]), - "removed": decode(info[b"removed"]) == "True" - } - return actor_info - def _job_length(self): event_log_sets = self.redis_client.keys("event_log*") overall_smallest = sys.maxsize @@ -754,21 +763,23 @@ def cluster_resources(self): A dictionary mapping resource name to the total quantity of that resource in the cluster. """ + self._check_connected() + resources = defaultdict(int) clients = self.client_table() for client in clients: - # Only count resources from live clients. - if client["IsInsertion"]: + # Only count resources from latest entries of live clients. + if client["EntryType"] != EntryType.DELETION: for key, value in client["Resources"].items(): resources[key] += value - return dict(resources) def _live_client_ids(self): """Returns a set of client IDs corresponding to clients still alive.""" return { client["ClientID"] - for client in self.client_table() if client["IsInsertion"] + for client in self.client_table() + if (client["EntryType"] != EntryType.DELETION) } def available_resources(self): @@ -783,6 +794,8 @@ def available_resources(self): A dictionary mapping resource name to the total quantity of that resource in the cluster. """ + self._check_connected() + available_resources_by_id = {} subscribe_clients = [ @@ -884,6 +897,8 @@ def error_messages(self, driver_id=None): A dictionary mapping driver ID to a list of the error messages for that driver. """ + self._check_connected() + if driver_id is not None: assert isinstance(driver_id, ray.DriverID) return self._error_messages(driver_id) @@ -939,3 +954,194 @@ def actor_checkpoint_info(self, actor_id): entry.Timestamps(i) for i in range(num_checkpoints) ], } + + +class DeprecatedGlobalState(object): + """A class used to print errors when the old global state API is used.""" + + def object_table(self, object_id=None): + logger.warning( + "ray.global_state.object_table() is deprecated and will be " + "removed in a subsequent release. Use ray.objects() instead.") + return ray.objects(object_id=object_id) + + def task_table(self, task_id=None): + logger.warning( + "ray.global_state.task_table() is deprecated and will be " + "removed in a subsequent release. 
Use ray.tasks() instead.") + return ray.tasks(task_id=task_id) + + def function_table(self, function_id=None): + raise DeprecationWarning( + "ray.global_state.function_table() is deprecated.") + + def client_table(self): + logger.warning( + "ray.global_state.client_table() is deprecated and will be " + "removed in a subsequent release. Use ray.nodes() instead.") + return ray.nodes() + + def profile_table(self): + raise DeprecationWarning( + "ray.global_state.profile_table() is deprecated.") + + def chrome_tracing_dump(self, filename=None): + logger.warning( + "ray.global_state.chrome_tracing_dump() is deprecated and will be " + "removed in a subsequent release. Use ray.timeline() instead.") + return ray.timeline(filename=filename) + + def chrome_tracing_object_transfer_dump(self, filename=None): + logger.warning( + "ray.global_state.chrome_tracing_object_transfer_dump() is " + "deprecated and will be removed in a subsequent release. Use " + "ray.object_transfer_timeline() instead.") + return ray.object_transfer_timeline(filename=filename) + + def workers(self): + raise DeprecationWarning("ray.global_state.workers() is deprecated.") + + def cluster_resources(self): + logger.warning( + "ray.global_state.cluster_resources() is deprecated and will be " + "removed in a subsequent release. Use ray.cluster_resources() " + "instead.") + return ray.cluster_resources() + + def available_resources(self): + logger.warning( + "ray.global_state.available_resources() is deprecated and will be " + "removed in a subsequent release. Use ray.available_resources() " + "instead.") + return ray.available_resources() + + def error_messages(self, driver_id=None): + logger.warning( + "ray.global_state.error_messages() is deprecated and will be " + "removed in a subsequent release. Use ray.errors() " + "instead.") + return ray.errors(driver_id=driver_id) + + +state = GlobalState() +"""A global object used to access the cluster's global state.""" + +global_state = DeprecatedGlobalState() + + +def nodes(): + """Get a list of the nodes in the cluster. + + Returns: + Information about the Ray clients in the cluster. + """ + return state.client_table() + + +def tasks(task_id=None): + """Fetch and parse the task table information for one or more task IDs. + + Args: + task_id: A hex string of the task ID to fetch information about. If + this is None, then the task object table is fetched. + + Returns: + Information from the task table. + """ + return state.task_table(task_id=task_id) + + +def objects(object_id=None): + """Fetch and parse the object table info for one or more object IDs. + + Args: + object_id: An object ID to fetch information about. If this is None, + then the entire object table is fetched. + + Returns: + Information from the object table. + """ + return state.object_table(object_id=object_id) + + +def timeline(filename=None): + """Return a list of profiling events that can be viewed as a timeline. + + To view this information as a timeline, simply dump it as a json file by + passing in "filename" or using json.dump, and then go to + chrome://tracing in the Chrome web browser and load the dumped file. + + Args: + filename: If a filename is provided, the timeline is dumped to that + file. + + Returns: + If filename is not provided, this returns a list of profiling events. + Each profile event is a dictionary. + """ + return state.chrome_tracing_dump(filename=filename) + + +def object_transfer_timeline(filename=None): + """Return a list of transfer events that can be viewed as a timeline. 
+ + To view this information as a timeline, simply dump it as a json file by + passing in "filename" or using json.dump, and then go to + chrome://tracing in the Chrome web browser and load the dumped file. Make + sure to enable "Flow events" in the "View Options" menu. + + Args: + filename: If a filename is provided, the timeline is dumped to that + file. + + Returns: + If filename is not provided, this returns a list of profiling events. + Each profile event is a dictionary. + """ + return state.chrome_tracing_object_transfer_dump(filename=filename) + + +def cluster_resources(): + """Get the current total cluster resources. + + Note that this information can grow stale as nodes are added to or removed + from the cluster. + + Returns: + A dictionary mapping resource name to the total quantity of that + resource in the cluster. + """ + return state.cluster_resources() + + +def available_resources(): + """Get the current available cluster resources. + + This is different from `cluster_resources` in that this will return idle + (available) resources rather than total resources. + + Note that this information can grow stale as tasks start and finish. + + Returns: + A dictionary mapping resource name to the total quantity of that + resource in the cluster. + """ + return state.available_resources() + + +def errors(include_cluster_errors=True): + """Get error messages from the cluster. + + Args: + include_cluster_errors: True if we should include error messages for + all drivers, and false if we should only include error messages for + this specific driver. + + Returns: + Error messages pushed from the cluster. + """ + worker = ray.worker.global_worker + error_messages = state.error_messages(driver_id=worker.task_driver_id) + if include_cluster_errors: + error_messages += state.error_messages(driver_id=ray.DriverID.nil()) + return error_messages diff --git a/python/ray/tests/cluster_utils.py b/python/ray/tests/cluster_utils.py index 0a7984d69740..703c3a1420ed 100644 --- a/python/ray/tests/cluster_utils.py +++ b/python/ray/tests/cluster_utils.py @@ -8,6 +8,7 @@ import redis import ray +from ray.core.generated.EntryType import EntryType logger = logging.getLogger(__name__) @@ -140,7 +141,7 @@ def _wait_for_node(self, node, timeout=30): start_time = time.time() while time.time() - start_time < timeout: - clients = ray.experimental.state.parse_client_table(redis_client) + clients = ray.state._parse_client_table(redis_client) object_store_socket_names = [ client["ObjectStoreSocketName"] for client in clients ] @@ -173,9 +174,10 @@ def wait_for_nodes(self, timeout=30): start_time = time.time() while time.time() - start_time < timeout: - clients = ray.experimental.state.parse_client_table(redis_client) + clients = ray.state._parse_client_table(redis_client) live_clients = [ - client for client in clients if client["IsInsertion"] + client for client in clients + if client["EntryType"] == EntryType.INSERTION ] expected = len(self.list_all_nodes()) diff --git a/python/ray/tests/test_actor.py b/python/ray/tests/test_actor.py index d7da081fd18c..dd726e00f27b 100644 --- a/python/ray/tests/test_actor.py +++ b/python/ray/tests/test_actor.py @@ -2439,7 +2439,7 @@ def save_checkpoint(self, actor_id, checkpoint_context): assert ray.get(actor.was_resumed_from_checkpoint.remote()) is False # Check that checkpointing errors were pushed to the driver. - errors = ray.error_info() + errors = ray.errors() assert len(errors) > 0 for error in errors: # An error for the actor process dying may also get pushed. 
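The hunks above and below migrate callers from the old ray.global_state accessors to the top-level helpers defined in python/ray/state.py. The following is a minimal driver-side sketch of how those helpers are invoked; the toy remote function and the output file path are illustrative assumptions, only the API names (ray.nodes, ray.tasks, ray.objects, ray.timeline, ray.cluster_resources, ray.available_resources, ray.errors) come from this patch:

import ray

ray.init(num_cpus=2)

@ray.remote
def square(x):
    # Toy workload so the task and object tables have entries to inspect.
    return x * x

ray.get([square.remote(i) for i in range(4)])

# Cluster topology and resource totals.
nodes = ray.nodes()                    # replaces ray.global_state.client_table()
totals = ray.cluster_resources()       # replaces ray.global_state.cluster_resources()
idle = ray.available_resources()       # replaces ray.global_state.available_resources()

# Task/object metadata and profiling output.
task_table = ray.tasks()               # replaces ray.global_state.task_table()
object_table = ray.objects()           # replaces ray.global_state.object_table()
ray.timeline(filename="/tmp/ray-timeline.json")  # replaces chrome_tracing_dump()

# Errors pushed to the driver, as asserted in the updated test_actor.py.
driver_errors = ray.errors()           # replaces ray.error_info()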
@@ -2483,7 +2483,7 @@ def load_checkpoint(self, actor_id, checkpoints): assert ray.get(actor.was_resumed_from_checkpoint.remote()) is False # Check that checkpointing errors were pushed to the driver. - errors = ray.error_info() + errors = ray.errors() assert len(errors) > 0 for error in errors: # An error for the actor process dying may also get pushed. diff --git a/python/ray/tests/test_basic.py b/python/ray/tests/test_basic.py index 3f8c7cb2b3a1..50aeca025362 100644 --- a/python/ray/tests/test_basic.py +++ b/python/ray/tests/test_basic.py @@ -7,6 +7,7 @@ from concurrent.futures import ThreadPoolExecutor import json import logging +from multiprocessing import Process import os import random import re @@ -28,7 +29,6 @@ import ray import ray.tests.cluster_utils import ray.tests.utils -from ray.utils import _random_string logger = logging.getLogger(__name__) @@ -303,6 +303,23 @@ def f(x): assert_equal(obj, ray.get(ray.put(obj))) +def test_nested_functions(ray_start_regular): + # Make sure that remote functions can use other values that are defined + # after the remote function but before the first function invocation. + @ray.remote + def f(): + return g(), ray.get(h.remote()) + + def g(): + return 1 + + @ray.remote + def h(): + return 2 + + assert ray.get(f.remote()) == (1, 2) + + def test_ray_recursive_objects(ray_start_regular): class ClassA(object): pass @@ -918,7 +935,7 @@ def f(block, accepted_resources): stop_time = time.time() + 10 correct_available_resources = False while time.time() < stop_time: - if ray.global_state.available_resources() == { + if ray.available_resources() == { "CPU": 2.0, "GPU": 2.0, "Custom": 2.0, @@ -1159,7 +1176,7 @@ def f(): if time.time() - start_time > timeout_seconds: raise Exception("Timed out while waiting for information in " "profile table.") - profile_data = ray.global_state.chrome_tracing_dump() + profile_data = ray.timeline() event_types = {event["cat"] for event in profile_data} expected_types = [ "worker_idle", @@ -1235,7 +1252,7 @@ def f(x): # The profiling information only flushes once every second. time.sleep(1.1) - transfer_dump = ray.global_state.chrome_tracing_object_transfer_dump() + transfer_dump = ray.object_transfer_timeline() # Make sure the transfer dump can be serialized with JSON. json.loads(json.dumps(transfer_dump)) assert len(transfer_dump) >= num_nodes**2 @@ -1542,12 +1559,12 @@ def run_one_test(actors, local_only, delete_creating_tasks): # Case3: These cases test the deleting creating tasks for the object. 
(a, b, c) = run_one_test(actors, False, False) - task_table = ray.global_state.task_table() + task_table = ray.tasks() for obj in [a, b, c]: assert ray._raylet.compute_task_id(obj).hex() in task_table (a, b, c) = run_one_test(actors, False, True) - task_table = ray.global_state.task_table() + task_table = ray.tasks() for obj in [a, b, c]: assert ray._raylet.compute_task_id(obj).hex() not in task_table @@ -2009,7 +2026,7 @@ def run_lots_of_tasks(): results.append(run_on_0_2.remote()) return names, results - client_table = ray.global_state.client_table() + client_table = ray.nodes() store_names = [] store_names += [ client["ObjectStoreSocketName"] for client in client_table @@ -2197,13 +2214,13 @@ def test_zero_capacity_deletion_semantics(shutdown_only): ray.init(num_cpus=2, num_gpus=1, resources={"test_resource": 1}) def test(): - resources = ray.global_state.available_resources() + resources = ray.available_resources() MAX_RETRY_ATTEMPTS = 5 retry_count = 0 while resources and retry_count < MAX_RETRY_ATTEMPTS: time.sleep(0.1) - resources = ray.global_state.available_resources() + resources = ray.available_resources() retry_count += 1 if retry_count >= MAX_RETRY_ATTEMPTS: @@ -2377,7 +2394,7 @@ def f(x): def wait_for_num_tasks(num_tasks, timeout=10): start_time = time.time() while time.time() - start_time < timeout: - if len(ray.global_state.task_table()) >= num_tasks: + if len(ray.tasks()) >= num_tasks: return time.sleep(0.1) raise Exception("Timed out while waiting for global state.") @@ -2386,7 +2403,7 @@ def wait_for_num_tasks(num_tasks, timeout=10): def wait_for_num_objects(num_objects, timeout=10): start_time = time.time() while time.time() - start_time < timeout: - if len(ray.global_state.object_table()) >= num_objects: + if len(ray.objects()) >= num_objects: return time.sleep(0.1) raise Exception("Timed out while waiting for global state.") @@ -2397,31 +2414,27 @@ def wait_for_num_objects(num_objects, timeout=10): reason="New GCS API doesn't have a Python API yet.") def test_global_state_api(shutdown_only): with pytest.raises(Exception): - ray.global_state.object_table() - - with pytest.raises(Exception): - ray.global_state.task_table() + ray.objects() with pytest.raises(Exception): - ray.global_state.client_table() + ray.tasks() with pytest.raises(Exception): - ray.global_state.function_table() + ray.nodes() ray.init(num_cpus=5, num_gpus=3, resources={"CustomResource": 1}) resources = {"CPU": 5, "GPU": 3, "CustomResource": 1} - assert ray.global_state.cluster_resources() == resources + assert ray.cluster_resources() == resources - assert ray.global_state.object_table() == {} + assert ray.objects() == {} - driver_id = ray.experimental.state.binary_to_hex( - ray.worker.global_worker.worker_id) + driver_id = ray.utils.binary_to_hex(ray.worker.global_worker.worker_id) driver_task_id = ray.worker.global_worker.current_task_id.hex() # One task is put in the task table which corresponds to this driver. 
wait_for_num_tasks(1) - task_table = ray.global_state.task_table() + task_table = ray.tasks() assert len(task_table) == 1 assert driver_task_id == list(task_table.keys())[0] task_spec = task_table[driver_task_id]["TaskSpec"] @@ -2434,7 +2447,7 @@ def test_global_state_api(shutdown_only): assert task_spec["FunctionID"] == nil_id_hex assert task_spec["ReturnObjectIDs"] == [] - client_table = ray.global_state.client_table() + client_table = ray.nodes() node_ip_address = ray.worker.global_worker.node_ip_address assert len(client_table) == 1 @@ -2449,24 +2462,19 @@ def f(*xs): # Wait for one additional task to complete. wait_for_num_tasks(1 + 1) - task_table = ray.global_state.task_table() + task_table = ray.tasks() assert len(task_table) == 1 + 1 task_id_set = set(task_table.keys()) task_id_set.remove(driver_task_id) task_id = list(task_id_set)[0] - function_table = ray.global_state.function_table() task_spec = task_table[task_id]["TaskSpec"] assert task_spec["ActorID"] == nil_id_hex assert task_spec["Args"] == [1, "hi", x_id] assert task_spec["DriverID"] == driver_id assert task_spec["ReturnObjectIDs"] == [result_id] - function_table_entry = function_table[task_spec["FunctionID"]] - assert function_table_entry["Name"] == "ray.tests.test_basic.f" - assert function_table_entry["DriverID"] == driver_id - assert function_table_entry["Module"] == "ray.tests.test_basic" - assert task_table[task_id] == ray.global_state.task_table(task_id) + assert task_table[task_id] == ray.tasks(task_id) # Wait for two objects, one for the x_id and one for result_id. wait_for_num_objects(2) @@ -2475,7 +2483,7 @@ def wait_for_object_table(): timeout = 10 start_time = time.time() while time.time() - start_time < timeout: - object_table = ray.global_state.object_table() + object_table = ray.objects() tables_ready = (object_table[x_id]["ManagerIDs"] is not None and object_table[result_id]["ManagerIDs"] is not None) if tables_ready: @@ -2484,11 +2492,11 @@ def wait_for_object_table(): raise Exception("Timed out while waiting for object table to " "update.") - object_table = ray.global_state.object_table() + object_table = ray.objects() assert len(object_table) == 2 - assert object_table[x_id] == ray.global_state.object_table(x_id) - object_table_entry = ray.global_state.object_table(result_id) + assert object_table[x_id] == ray.objects(x_id) + object_table_entry = ray.objects(result_id) assert object_table[result_id] == object_table_entry @@ -2594,14 +2602,6 @@ def f(): while len(worker_ids) != num_workers: worker_ids = set(ray.get([f.remote() for _ in range(10)])) - worker_info = ray.global_state.workers() - assert len(worker_info) >= num_workers - for worker_id, info in worker_info.items(): - assert "node_ip_address" in info - assert "plasma_store_socket" in info - assert "stderr_file" in info - assert "stdout_file" in info - def test_specific_driver_id(): dummy_driver_id = ray.DriverID(b"00112233445566778899") @@ -2630,12 +2630,33 @@ def test_object_id_properties(): ray.ObjectID(id_bytes + b"1234") with pytest.raises(ValueError, match=r".*needs to have length 20.*"): ray.ObjectID(b"0123456789") - object_id = ray.ObjectID(_random_string()) + object_id = ray.ObjectID.from_random() assert not object_id.is_nil() assert object_id.binary() != id_bytes id_dumps = pickle.dumps(object_id) id_from_dumps = pickle.loads(id_dumps) assert id_from_dumps == object_id + file_prefix = "test_object_id_properties" + + # Make sure the ids are fork safe. 
+ def write(index): + str = ray.ObjectID.from_random().hex() + with open("{}{}".format(file_prefix, index), "w") as fo: + fo.write(str) + + def read(index): + with open("{}{}".format(file_prefix, index), "r") as fi: + for line in fi: + return line + + processes = [Process(target=write, args=(_, )) for _ in range(4)] + for process in processes: + process.start() + for process in processes: + process.join() + hexes = {read(i) for i in range(4)} + [os.remove("{}{}".format(file_prefix, i)) for i in range(4)] + assert len(hexes) == 4 @pytest.fixture @@ -2768,7 +2789,7 @@ def test_pandas_parquet_serialization(): def test_socket_dir_not_existing(shutdown_only): - random_name = ray.ObjectID(_random_string()).hex() + random_name = ray.ObjectID.from_random().hex() temp_raylet_socket_dir = "/tmp/ray/tests/{}".format(random_name) temp_raylet_socket_name = os.path.join(temp_raylet_socket_dir, "raylet_socket") @@ -2778,7 +2799,7 @@ def test_socket_dir_not_existing(shutdown_only): def test_raylet_is_robust_to_random_messages(ray_start_regular): node_manager_address = None node_manager_port = None - for client in ray.global_state.client_table(): + for client in ray.nodes(): if "NodeManagerAddress" in client: node_manager_address = client["NodeManagerAddress"] node_manager_port = client["NodeManagerPort"] @@ -2870,7 +2891,7 @@ def test_shutdown_disconnect_global_state(): ray.shutdown() with pytest.raises(Exception) as e: - ray.global_state.object_table() + ray.objects() assert str(e.value).endswith("ray.init has been called.") @@ -2884,8 +2905,8 @@ def test_redis_lru_with_set(ray_start_object_store_memory): removed = False start_time = time.time() while time.time() < start_time + 10: - if ray.global_state.redis_clients[0].delete(b"OBJECT" + - x_id.binary()) == 1: + if ray.state.state.redis_clients[0].delete(b"OBJECT" + + x_id.binary()) == 1: removed = True break assert removed @@ -2921,3 +2942,43 @@ def get_postprocessor(object_ids, values): assert ray.get( [ray.put(i) for i in [0, 1, 3, 5, -1, -3, 4]]) == [1, 3, 5, 4] + + +def test_export_after_shutdown(ray_start_regular): + # This test checks that we can use actor and remote function definitions + # across multiple Ray sessions. + + @ray.remote + def f(): + pass + + @ray.remote + class Actor(object): + def method(self): + pass + + ray.get(f.remote()) + a = Actor.remote() + ray.get(a.method.remote()) + + ray.shutdown() + + # Start Ray and use the remote function and actor again. + ray.init(num_cpus=1) + ray.get(f.remote()) + a = Actor.remote() + ray.get(a.method.remote()) + + ray.shutdown() + + # Start Ray again and make sure that these definitions can be exported from + # workers. 
+ ray.init(num_cpus=2) + + @ray.remote + def export_definitions_from_worker(remote_function, actor_class): + ray.get(remote_function.remote()) + actor_handle = actor_class.remote() + ray.get(actor_handle.method.remote()) + + ray.get(export_definitions_from_worker.remote(f, Actor)) diff --git a/python/ray/tests/test_dynres.py b/python/ray/tests/test_dynres.py new file mode 100644 index 000000000000..ea647adf1207 --- /dev/null +++ b/python/ray/tests/test_dynres.py @@ -0,0 +1,571 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import logging +import time + +import ray +import ray.tests.cluster_utils +import ray.tests.utils + +logger = logging.getLogger(__name__) + + +def test_dynamic_res_creation(ray_start_regular): + # This test creates a resource locally (without specifying the client_id) + res_name = "test_res" + res_capacity = 1.0 + + @ray.remote + def set_res(resource_name, resource_capacity): + ray.experimental.set_resource(resource_name, resource_capacity) + + ray.get(set_res.remote(res_name, res_capacity)) + + available_res = ray.available_resources() + cluster_res = ray.cluster_resources() + + assert available_res[res_name] == res_capacity + assert cluster_res[res_name] == res_capacity + + +def test_dynamic_res_deletion(shutdown_only): + # This test deletes a resource locally (without specifying the client_id) + res_name = "test_res" + res_capacity = 1.0 + + ray.init(num_cpus=1, resources={res_name: res_capacity}) + + @ray.remote + def delete_res(resource_name): + ray.experimental.set_resource(resource_name, 0) + + ray.get(delete_res.remote(res_name)) + + available_res = ray.available_resources() + cluster_res = ray.cluster_resources() + + assert res_name not in available_res + assert res_name not in cluster_res + + +def test_dynamic_res_infeasible_rescheduling(ray_start_regular): + # This test launches an infeasible task and then creates a + # resource to make the task feasible. This tests if the + # infeasible tasks get rescheduled when resources are + # created at runtime. 
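These tests revolve around the new dynamic custom resource API, ray.experimental.set_resource. A minimal sketch, mirroring how the tests invoke it from a remote task; the resource name is arbitrary, a capacity of 0 deletes the resource, and an optional client_id targets a specific node:

import time
import ray

ray.init(num_cpus=1)

@ray.remote
def set_res(name, capacity):
    # Capacity 0 deletes the resource; client_id= would target a specific node.
    ray.experimental.set_resource(name, capacity)

ray.get(set_res.remote("sketch_res", 2.0))    # create at runtime
deadline = time.time() + 5
while ray.cluster_resources().get("sketch_res") != 2.0 and time.time() < deadline:
    time.sleep(0.1)

ray.get(set_res.remote("sketch_res", 0))      # setting capacity 0 deletes it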
+ res_name = "test_res" + res_capacity = 1.0 + + @ray.remote + def set_res(resource_name, resource_capacity): + ray.experimental.set_resource(resource_name, resource_capacity) + + def f(): + return 1 + + remote_task = ray.remote(resources={res_name: res_capacity})(f) + oid = remote_task.remote() # This is infeasible + ray.get(set_res.remote(res_name, res_capacity)) # Now should be feasible + + available_res = ray.available_resources() + assert available_res[res_name] == res_capacity + + successful, unsuccessful = ray.wait([oid], timeout=1) + assert successful # The task completed + + +def test_dynamic_res_updation_clientid(ray_start_cluster): + # This test does a simple resource capacity update + cluster = ray_start_cluster + + res_name = "test_res" + res_capacity = 1.0 + num_nodes = 3 + for i in range(num_nodes): + cluster.add_node() + + ray.init(redis_address=cluster.redis_address) + + target_clientid = ray.nodes()[1]["ClientID"] + + @ray.remote + def set_res(resource_name, resource_capacity, client_id): + ray.experimental.set_resource( + resource_name, resource_capacity, client_id=client_id) + + # Create resource + ray.get(set_res.remote(res_name, res_capacity, target_clientid)) + + # Update resource + new_capacity = res_capacity + 1 + ray.get(set_res.remote(res_name, new_capacity, target_clientid)) + + target_client = next(client for client in ray.nodes() + if client["ClientID"] == target_clientid) + resources = target_client["Resources"] + + assert res_name in resources + assert resources[res_name] == new_capacity + + +def test_dynamic_res_creation_clientid(ray_start_cluster): + # Creates a resource on a specific client and verifies creation. + cluster = ray_start_cluster + + res_name = "test_res" + res_capacity = 1.0 + num_nodes = 3 + for i in range(num_nodes): + cluster.add_node() + + ray.init(redis_address=cluster.redis_address) + + target_clientid = ray.nodes()[1]["ClientID"] + + @ray.remote + def set_res(resource_name, resource_capacity, res_client_id): + ray.experimental.set_resource( + resource_name, resource_capacity, client_id=res_client_id) + + ray.get(set_res.remote(res_name, res_capacity, target_clientid)) + target_client = next(client for client in ray.nodes() + if client["ClientID"] == target_clientid) + resources = target_client["Resources"] + + assert res_name in resources + assert resources[res_name] == res_capacity + + +def test_dynamic_res_creation_clientid_multiple(ray_start_cluster): + # This test creates resources on multiple clients using the clientid + # specifier + cluster = ray_start_cluster + + TIMEOUT = 5 + res_name = "test_res" + res_capacity = 1.0 + num_nodes = 3 + for i in range(num_nodes): + cluster.add_node() + + ray.init(redis_address=cluster.redis_address) + + target_clientids = [client["ClientID"] for client in ray.nodes()] + + @ray.remote + def set_res(resource_name, resource_capacity, res_client_id): + ray.experimental.set_resource( + resource_name, resource_capacity, client_id=res_client_id) + + results = [] + for cid in target_clientids: + results.append(set_res.remote(res_name, res_capacity, cid)) + ray.get(results) + + success = False + start_time = time.time() + + while time.time() - start_time < TIMEOUT and not success: + resources_created = [] + for cid in target_clientids: + target_client = next( + client for client in ray.nodes() if client["ClientID"] == cid) + resources = target_client["Resources"] + resources_created.append(resources[res_name] == res_capacity) + success = all(resources_created) + assert success + + +def 
test_dynamic_res_deletion_clientid(ray_start_cluster): + # This test deletes a resource on a given client id + cluster = ray_start_cluster + + res_name = "test_res" + res_capacity = 1.0 + num_nodes = 5 + + for i in range(num_nodes): + # Create resource on all nodes, but later we'll delete it from a + # target node + cluster.add_node(resources={res_name: res_capacity}) + + ray.init(redis_address=cluster.redis_address) + + target_clientid = ray.nodes()[1]["ClientID"] + + # Launch the delete task + @ray.remote + def delete_res(resource_name, res_client_id): + ray.experimental.set_resource( + resource_name, 0, client_id=res_client_id) + + ray.get(delete_res.remote(res_name, target_clientid)) + + target_client = next(client for client in ray.nodes() + if client["ClientID"] == target_clientid) + resources = target_client["Resources"] + print(ray.cluster_resources()) + assert res_name not in resources + + +def test_dynamic_res_creation_scheduler_consistency(ray_start_cluster): + # This makes sure the resource is actually created and the state is + # consistent in the scheduler + # by launching a task which requests the created resource + cluster = ray_start_cluster + + res_name = "test_res" + res_capacity = 1.0 + num_nodes = 5 + + for i in range(num_nodes): + cluster.add_node() + + ray.init(redis_address=cluster.redis_address) + + clientids = [client["ClientID"] for client in ray.nodes()] + + @ray.remote + def set_res(resource_name, resource_capacity, res_client_id): + ray.experimental.set_resource( + resource_name, resource_capacity, client_id=res_client_id) + + # Create the resource on node1 + target_clientid = clientids[1] + ray.get(set_res.remote(res_name, res_capacity, target_clientid)) + + # Define a task which requires this resource + @ray.remote(resources={res_name: res_capacity}) + def test_func(): + return 1 + + result = test_func.remote() + successful, unsuccessful = ray.wait([result], timeout=5) + assert successful # The task completed + + +def test_dynamic_res_deletion_scheduler_consistency(ray_start_cluster): + # This makes sure the resource is actually deleted and the state is + # consistent in the scheduler by launching an infeasible task which + # requests the created resource + cluster = ray_start_cluster + + res_name = "test_res" + res_capacity = 1.0 + num_nodes = 5 + TIMEOUT_DURATION = 1 + + for i in range(num_nodes): + cluster.add_node() + + ray.init(redis_address=cluster.redis_address) + + clientids = [client["ClientID"] for client in ray.nodes()] + + @ray.remote + def delete_res(resource_name, res_client_id): + ray.experimental.set_resource( + resource_name, 0, client_id=res_client_id) + + @ray.remote + def set_res(resource_name, resource_capacity, res_client_id): + ray.experimental.set_resource( + resource_name, resource_capacity, client_id=res_client_id) + + # Create the resource on node1 + target_clientid = clientids[1] + ray.get(set_res.remote(res_name, res_capacity, target_clientid)) + assert ray.cluster_resources()[res_name] == res_capacity + + # Delete the resource + ray.get(delete_res.remote(res_name, target_clientid)) + + # Define a task which requires this resource. 
This should not run + @ray.remote(resources={res_name: res_capacity}) + def test_func(): + return 1 + + result = test_func.remote() + successful, unsuccessful = ray.wait([result], timeout=TIMEOUT_DURATION) + assert unsuccessful # The task did not complete because it's infeasible + + +def test_dynamic_res_concurrent_res_increment(ray_start_cluster): + # This test makes sure resource capacity is updated (increment) correctly + # when a task has already acquired some of the resource. + + cluster = ray_start_cluster + + res_name = "test_res" + res_capacity = 5 + updated_capacity = 10 + num_nodes = 5 + TIMEOUT_DURATION = 1 + + # Create a object ID to have the task wait on + WAIT_OBJECT_ID_STR = ("a" * 20).encode("ascii") + + # Create a object ID to signal that the task is running + TASK_RUNNING_OBJECT_ID_STR = ("b" * 20).encode("ascii") + + for i in range(num_nodes): + cluster.add_node() + + ray.init(redis_address=cluster.redis_address) + + clientids = [client["ClientID"] for client in ray.nodes()] + target_clientid = clientids[1] + + @ray.remote + def set_res(resource_name, resource_capacity, res_client_id): + ray.experimental.set_resource( + resource_name, resource_capacity, client_id=res_client_id) + + # Create the resource on node 1 + ray.get(set_res.remote(res_name, res_capacity, target_clientid)) + assert ray.cluster_resources()[res_name] == res_capacity + + # Task to hold the resource till the driver signals to finish + @ray.remote + def wait_func(running_oid, wait_oid): + # Signal that the task is running + ray.worker.global_worker.put_object(ray.ObjectID(running_oid), 1) + # Make the task wait till signalled by driver + ray.get(ray.ObjectID(wait_oid)) + + @ray.remote + def test_func(): + return 1 + + # Launch the task with resource requirement of 4, thus the new available + # capacity becomes 1 + task = wait_func._remote( + args=[TASK_RUNNING_OBJECT_ID_STR, WAIT_OBJECT_ID_STR], + resources={res_name: 4}) + # Wait till wait_func is launched before updating resource + ray.get(ray.ObjectID(TASK_RUNNING_OBJECT_ID_STR)) + + # Update the resource capacity + ray.get(set_res.remote(res_name, updated_capacity, target_clientid)) + + # Signal task to complete + ray.worker.global_worker.put_object(ray.ObjectID(WAIT_OBJECT_ID_STR), 1) + ray.get(task) + + # Check if scheduler state is consistent by launching a task requiring + # updated capacity + task_2 = test_func._remote(args=[], resources={res_name: updated_capacity}) + successful, unsuccessful = ray.wait([task_2], timeout=TIMEOUT_DURATION) + assert successful # The task completed + + # Check if scheduler state is consistent by launching a task requiring + # updated capacity + 1. This should not execute + task_3 = test_func._remote( + args=[], resources={res_name: updated_capacity + 1 + }) # This should be infeasible + successful, unsuccessful = ray.wait([task_3], timeout=TIMEOUT_DURATION) + assert unsuccessful # The task did not complete because it's infeasible + assert ray.available_resources()[res_name] == updated_capacity + + +def test_dynamic_res_concurrent_res_decrement(ray_start_cluster): + # This test makes sure resource capacity is updated (decremented) + # correctly when a task has already acquired some + # of the resource. 
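The concurrency tests above attach per-call resource demands through the private _remote helper and use ray.wait with a timeout to tell feasible work from infeasible work. A short sketch of that pattern under the same assumptions (a fresh single-CPU session, and a resource name that does not exist):

import ray

ray.init(num_cpus=1)

@ray.remote
def work():
    return 1

# Per-call resource override used throughout these tests: the (private)
# `_remote` helper attaches a resource demand without redefining the function.
feasible = work._remote(args=[], resources={"CPU": 1})
infeasible = work._remote(args=[], resources={"nonexistent_res": 1})

ready, _ = ray.wait([feasible], timeout=1)
assert ready                     # schedulable, so it completes
ready, not_ready = ray.wait([infeasible], timeout=1)
assert not_ready                 # no such resource yet, so it stays pending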
+ + cluster = ray_start_cluster + + res_name = "test_res" + res_capacity = 5 + updated_capacity = 2 + num_nodes = 5 + TIMEOUT_DURATION = 1 + + # Create a object ID to have the task wait on + WAIT_OBJECT_ID_STR = ("a" * 20).encode("ascii") + + # Create a object ID to signal that the task is running + TASK_RUNNING_OBJECT_ID_STR = ("b" * 20).encode("ascii") + + for i in range(num_nodes): + cluster.add_node() + + ray.init(redis_address=cluster.redis_address) + + clientids = [client["ClientID"] for client in ray.nodes()] + target_clientid = clientids[1] + + @ray.remote + def set_res(resource_name, resource_capacity, res_client_id): + ray.experimental.set_resource( + resource_name, resource_capacity, client_id=res_client_id) + + # Create the resource on node 1 + ray.get(set_res.remote(res_name, res_capacity, target_clientid)) + assert ray.cluster_resources()[res_name] == res_capacity + + # Task to hold the resource till the driver signals to finish + @ray.remote + def wait_func(running_oid, wait_oid): + # Signal that the task is running + ray.worker.global_worker.put_object(ray.ObjectID(running_oid), 1) + # Make the task wait till signalled by driver + ray.get(ray.ObjectID(wait_oid)) + + @ray.remote + def test_func(): + return 1 + + # Launch the task with resource requirement of 4, thus the new available + # capacity becomes 1 + task = wait_func._remote( + args=[TASK_RUNNING_OBJECT_ID_STR, WAIT_OBJECT_ID_STR], + resources={res_name: 4}) + # Wait till wait_func is launched before updating resource + ray.get(ray.ObjectID(TASK_RUNNING_OBJECT_ID_STR)) + + # Decrease the resource capacity + ray.get(set_res.remote(res_name, updated_capacity, target_clientid)) + + # Signal task to complete + ray.worker.global_worker.put_object(ray.ObjectID(WAIT_OBJECT_ID_STR), 1) + ray.get(task) + + # Check if scheduler state is consistent by launching a task requiring + # updated capacity + task_2 = test_func._remote(args=[], resources={res_name: updated_capacity}) + successful, unsuccessful = ray.wait([task_2], timeout=TIMEOUT_DURATION) + assert successful # The task completed + + # Check if scheduler state is consistent by launching a task requiring + # updated capacity + 1. 
This should not execute + task_3 = test_func._remote( + args=[], resources={res_name: updated_capacity + 1 + }) # This should be infeasible + successful, unsuccessful = ray.wait([task_3], timeout=TIMEOUT_DURATION) + assert unsuccessful # The task did not complete because it's infeasible + assert ray.available_resources()[res_name] == updated_capacity + + +def test_dynamic_res_concurrent_res_delete(ray_start_cluster): + # This test makes sure resource gets deleted correctly when a task has + # already acquired the resource + + cluster = ray_start_cluster + + res_name = "test_res" + res_capacity = 5 + num_nodes = 5 + TIMEOUT_DURATION = 1 + + # Create a object ID to have the task wait on + WAIT_OBJECT_ID_STR = ("a" * 20).encode("ascii") + + # Create a object ID to signal that the task is running + TASK_RUNNING_OBJECT_ID_STR = ("b" * 20).encode("ascii") + + for i in range(num_nodes): + cluster.add_node() + + ray.init(redis_address=cluster.redis_address) + + clientids = [client["ClientID"] for client in ray.nodes()] + target_clientid = clientids[1] + + @ray.remote + def set_res(resource_name, resource_capacity, res_client_id): + ray.experimental.set_resource( + resource_name, resource_capacity, client_id=res_client_id) + + @ray.remote + def delete_res(resource_name, res_client_id): + ray.experimental.set_resource( + resource_name, 0, client_id=res_client_id) + + # Create the resource on node 1 + ray.get(set_res.remote(res_name, res_capacity, target_clientid)) + assert ray.cluster_resources()[res_name] == res_capacity + + # Task to hold the resource till the driver signals to finish + @ray.remote + def wait_func(running_oid, wait_oid): + # Signal that the task is running + ray.worker.global_worker.put_object(ray.ObjectID(running_oid), 1) + # Make the task wait till signalled by driver + ray.get(ray.ObjectID(wait_oid)) + + @ray.remote + def test_func(): + return 1 + + # Launch the task with resource requirement of 4, thus the new available + # capacity becomes 1 + task = wait_func._remote( + args=[TASK_RUNNING_OBJECT_ID_STR, WAIT_OBJECT_ID_STR], + resources={res_name: 4}) + # Wait till wait_func is launched before updating resource + ray.get(ray.ObjectID(TASK_RUNNING_OBJECT_ID_STR)) + + # Delete the resource + ray.get(delete_res.remote(res_name, target_clientid)) + + # Signal task to complete + ray.worker.global_worker.put_object(ray.ObjectID(WAIT_OBJECT_ID_STR), 1) + ray.get(task) + + # Check if scheduler state is consistent by launching a task requiring + # the deleted resource This should not execute + task_2 = test_func._remote( + args=[], resources={res_name: 1}) # This should be infeasible + successful, unsuccessful = ray.wait([task_2], timeout=TIMEOUT_DURATION) + assert unsuccessful # The task did not complete because it's infeasible + assert res_name not in ray.available_resources() + + +def test_dynamic_res_creation_stress(ray_start_cluster): + # This stress tests creates many resources simultaneously on the same + # client and then checks if the final state is consistent + + cluster = ray_start_cluster + + TIMEOUT = 5 + res_capacity = 1 + num_nodes = 5 + NUM_RES_TO_CREATE = 500 + + for i in range(num_nodes): + cluster.add_node() + + ray.init(redis_address=cluster.redis_address) + + clientids = [client["ClientID"] for client in ray.nodes()] + target_clientid = clientids[1] + + @ray.remote + def set_res(resource_name, resource_capacity, res_client_id): + ray.experimental.set_resource( + resource_name, resource_capacity, client_id=res_client_id) + + @ray.remote + def 
delete_res(resource_name, res_client_id): + ray.experimental.set_resource( + resource_name, 0, client_id=res_client_id) + + results = [ + set_res.remote(str(i), res_capacity, target_clientid) + for i in range(0, NUM_RES_TO_CREATE) + ] + ray.get(results) + + success = False + start_time = time.time() + + while time.time() - start_time < TIMEOUT and not success: + resources = ray.cluster_resources() + all_resources_created = [] + for i in range(0, NUM_RES_TO_CREATE): + all_resources_created.append(str(i) in resources) + success = all(all_resources_created) + assert success diff --git a/python/ray/tests/test_failure.py b/python/ray/tests/test_failure.py index 8fb58e576ea1..51b906695c2d 100644 --- a/python/ray/tests/test_failure.py +++ b/python/ray/tests/test_failure.py @@ -15,7 +15,6 @@ import ray import ray.ray_constants as ray_constants -from ray.utils import _random_string from ray.tests.cluster_utils import Cluster from ray.tests.utils import ( relevant_errors, @@ -96,7 +95,15 @@ def temporary_helper_function(): # fail when it is unpickled. @ray.remote def g(): - return module.temporary_python_file() + try: + module.temporary_python_file() + except Exception: + # This test is not concerned with the error from running this + # function. Only from unpickling the remote function. + pass + + # Invoke the function so that the definition is exported. + g.remote() wait_for_errors(ray_constants.REGISTER_REMOTE_FUNCTION_PUSH_ERROR, 2) errors = relevant_errors(ray_constants.REGISTER_REMOTE_FUNCTION_PUSH_ERROR) @@ -157,7 +164,7 @@ def get_val(self): return 1 # There should be no errors yet. - assert len(ray.error_info()) == 0 + assert len(ray.errors()) == 0 # Create an actor. foo = Foo.remote() @@ -369,8 +376,9 @@ class Actor(object): a = Actor.remote() a.__ray_terminate__.remote() time.sleep(1) - assert len(ray.error_info()) == 0, ( - "Should not have propogated an error - {}".format(ray.error_info())) + assert len( + ray.errors()) == 0, ("Should not have propogated an error - {}".format( + ray.errors())) @pytest.mark.skip("This test does not work yet.") @@ -500,6 +508,9 @@ def test_export_large_objects(ray_start_regular): def f(): large_object + # Invoke the function so that the definition is exported. + f.remote() + # Make sure that a warning is generated. wait_for_errors(ray_constants.PICKLING_LARGE_OBJECT_PUSH_ERROR, 1) @@ -643,7 +654,7 @@ def test_warning_for_dead_node(ray_start_cluster_2_nodes): cluster = ray_start_cluster_2_nodes cluster.wait_for_nodes() - client_ids = {item["ClientID"] for item in ray.global_state.client_table()} + client_ids = {item["ClientID"] for item in ray.nodes()} # Try to make sure that the monitor has received at least one heartbeat # from the node. @@ -667,7 +678,7 @@ def test_warning_for_dead_node(ray_start_cluster_2_nodes): def test_raylet_crash_when_get(ray_start_regular): - nonexistent_id = ray.ObjectID(_random_string()) + nonexistent_id = ray.ObjectID.from_random() def sleep_to_kill_raylet(): # Don't kill raylet before default workers get connected. 
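test_failure.py now reads errors via ray.errors() instead of ray.error_info(); each entry is a dict carrying at least "type" and "message" fields, as the helpers in ray/tests/utils.py further down assume. A hedged sketch of the polling pattern those helpers implement:

import time
import ray

def wait_for_error_type(error_type, timeout=10):
    # Poll ray.errors() (formerly ray.error_info()) until an error of the
    # requested type shows up, mirroring relevant_errors/wait_for_errors.
    deadline = time.time() + timeout
    while time.time() < deadline:
        matching = [e for e in ray.errors() if e["type"] == error_type]
        if matching:
            return matching
        time.sleep(0.1)
    raise RuntimeError("Timed out waiting for {} errors.".format(error_type))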
diff --git a/python/ray/tests/test_global_state.py b/python/ray/tests/test_global_state.py index bc82eb8590c3..db71fc69c73b 100644 --- a/python/ray/tests/test_global_state.py +++ b/python/ray/tests/test_global_state.py @@ -18,8 +18,8 @@ reason="Timeout package not installed; skipping test that may hang.") @pytest.mark.timeout(10) def test_replenish_resources(ray_start_regular): - cluster_resources = ray.global_state.cluster_resources() - available_resources = ray.global_state.available_resources() + cluster_resources = ray.cluster_resources() + available_resources = ray.available_resources() assert cluster_resources == available_resources @ray.remote @@ -30,7 +30,7 @@ def cpu_task(): resources_reset = False while not resources_reset: - available_resources = ray.global_state.available_resources() + available_resources = ray.available_resources() resources_reset = (cluster_resources == available_resources) assert resources_reset @@ -40,7 +40,7 @@ def cpu_task(): reason="Timeout package not installed; skipping test that may hang.") @pytest.mark.timeout(10) def test_uses_resources(ray_start_regular): - cluster_resources = ray.global_state.cluster_resources() + cluster_resources = ray.cluster_resources() @ray.remote def cpu_task(): @@ -50,7 +50,7 @@ def cpu_task(): resource_used = False while not resource_used: - available_resources = ray.global_state.available_resources() + available_resources = ray.available_resources() resource_used = available_resources.get( "CPU", 0) == cluster_resources.get("CPU", 0) - 1 @@ -64,17 +64,17 @@ def cpu_task(): def test_add_remove_cluster_resources(ray_start_cluster_head): """Tests that Global State API is consistent with actual cluster.""" cluster = ray_start_cluster_head - assert ray.global_state.cluster_resources()["CPU"] == 1 + assert ray.cluster_resources()["CPU"] == 1 nodes = [] nodes += [cluster.add_node(num_cpus=1)] cluster.wait_for_nodes() - assert ray.global_state.cluster_resources()["CPU"] == 2 + assert ray.cluster_resources()["CPU"] == 2 cluster.remove_node(nodes.pop()) cluster.wait_for_nodes() - assert ray.global_state.cluster_resources()["CPU"] == 1 + assert ray.cluster_resources()["CPU"] == 1 for i in range(5): nodes += [cluster.add_node(num_cpus=1)] cluster.wait_for_nodes() - assert ray.global_state.cluster_resources()["CPU"] == 6 + assert ray.cluster_resources()["CPU"] == 6 diff --git a/python/ray/tests/test_monitors.py b/python/ray/tests/test_monitors.py index d588732c11f4..9eebe7e45087 100644 --- a/python/ray/tests/test_monitors.py +++ b/python/ray/tests/test_monitors.py @@ -30,29 +30,21 @@ def _test_cleanup_on_driver_exit(num_redis_shards): time.sleep(2) def StateSummary(): - obj_tbl_len = len(ray.global_state.object_table()) - task_tbl_len = len(ray.global_state.task_table()) - func_tbl_len = len(ray.global_state.function_table()) - return obj_tbl_len, task_tbl_len, func_tbl_len + obj_tbl_len = len(ray.objects()) + task_tbl_len = len(ray.tasks()) + return obj_tbl_len, task_tbl_len def Driver(success): success.value = True # Start driver. ray.init(redis_address=redis_address) summary_start = StateSummary() - if (0, 1) != summary_start[:2]: + if (0, 1) != summary_start: success.value = False # Two new objects. ray.get(ray.put(1111)) ray.get(ray.put(1111)) - attempts = 0 - while (2, 1, summary_start[2]) != StateSummary(): - time.sleep(0.1) - attempts += 1 - if attempts == max_attempts_before_failing: - success.value = False - break @ray.remote def f(): @@ -61,7 +53,7 @@ def f(): # 1 new function. 
attempts = 0 - while (2, 1, summary_start[2] + 1) != StateSummary(): + while (2, 1) != StateSummary(): time.sleep(0.1) attempts += 1 if attempts == max_attempts_before_failing: @@ -70,7 +62,7 @@ def f(): ray.get(f.remote()) attempts = 0 - while (4, 2, summary_start[2] + 1) != StateSummary(): + while (4, 2) != StateSummary(): time.sleep(0.1) attempts += 1 if attempts == max_attempts_before_failing: @@ -90,12 +82,12 @@ def f(): # Check that objects, tasks, and functions are cleaned up. ray.init(redis_address=redis_address) attempts = 0 - while (0, 1) != StateSummary()[:2]: + while (0, 1) != StateSummary(): time.sleep(0.1) attempts += 1 if attempts == max_attempts_before_failing: break - assert (0, 1) == StateSummary()[:2] + assert (0, 1) == StateSummary() ray.shutdown() subprocess.Popen(["ray", "stop"]).wait() diff --git a/python/ray/tests/test_multi_node.py b/python/ray/tests/test_multi_node.py index a963f6b15ea1..07f0d621c483 100644 --- a/python/ray/tests/test_multi_node.py +++ b/python/ray/tests/test_multi_node.py @@ -19,7 +19,7 @@ def test_error_isolation(call_ray_start): ray.init(redis_address=redis_address) # There shouldn't be any errors yet. - assert len(ray.error_info()) == 0 + assert len(ray.errors()) == 0 error_string1 = "error_string1" error_string2 = "error_string2" @@ -33,13 +33,13 @@ def f(): ray.get(f.remote()) # Wait for the error to appear in Redis. - while len(ray.error_info()) != 1: + while len(ray.errors()) != 1: time.sleep(0.1) print("Waiting for error to appear.") # Make sure we got the error. - assert len(ray.error_info()) == 1 - assert error_string1 in ray.error_info()[0]["message"] + assert len(ray.errors()) == 1 + assert error_string1 in ray.errors()[0]["message"] # Start another driver and make sure that it does not receive this # error. Make the other driver throw an error, and make sure it @@ -51,7 +51,7 @@ def f(): ray.init(redis_address="{}") time.sleep(1) -assert len(ray.error_info()) == 0 +assert len(ray.errors()) == 0 @ray.remote def f(): @@ -62,12 +62,12 @@ def f(): except Exception as e: pass -while len(ray.error_info()) != 1: - print(len(ray.error_info())) +while len(ray.errors()) != 1: + print(len(ray.errors())) time.sleep(0.1) -assert len(ray.error_info()) == 1 +assert len(ray.errors()) == 1 -assert "{}" in ray.error_info()[0]["message"] +assert "{}" in ray.errors()[0]["message"] print("success") """.format(redis_address, error_string2, error_string2) @@ -78,8 +78,8 @@ def f(): # Make sure that the other error message doesn't show up for this # driver. 
- assert len(ray.error_info()) == 1 - assert error_string1 in ray.error_info()[0]["message"] + assert len(ray.errors()) == 1 + assert error_string1 in ray.errors()[0]["message"] def test_remote_function_isolation(call_ray_start): diff --git a/python/ray/tests/test_multi_node_2.py b/python/ray/tests/test_multi_node_2.py index e66a3799e25e..979f4728330f 100644 --- a/python/ray/tests/test_multi_node_2.py +++ b/python/ray/tests/test_multi_node_2.py @@ -52,10 +52,10 @@ def test_internal_config(ray_start_cluster_head): cluster.remove_node(worker) time.sleep(1) - assert ray.global_state.cluster_resources()["CPU"] == 2 + assert ray.cluster_resources()["CPU"] == 2 time.sleep(2) - assert ray.global_state.cluster_resources()["CPU"] == 1 + assert ray.cluster_resources()["CPU"] == 1 def test_wait_for_nodes(ray_start_cluster_head): @@ -70,12 +70,12 @@ def test_wait_for_nodes(ray_start_cluster_head): [cluster.remove_node(w) for w in workers] cluster.wait_for_nodes() - assert ray.global_state.cluster_resources()["CPU"] == 1 + assert ray.cluster_resources()["CPU"] == 1 worker2 = cluster.add_node() cluster.wait_for_nodes() cluster.remove_node(worker2) cluster.wait_for_nodes() - assert ray.global_state.cluster_resources()["CPU"] == 1 + assert ray.cluster_resources()["CPU"] == 1 def test_worker_plasma_store_failure(ray_start_cluster_head): diff --git a/python/ray/tests/test_object_manager.py b/python/ray/tests/test_object_manager.py index e02e3d9a7d6e..bbe47a7e47d0 100644 --- a/python/ray/tests/test_object_manager.py +++ b/python/ray/tests/test_object_manager.py @@ -80,7 +80,7 @@ def create_object(): # Wait for profiling information to be pushed to the profile table. time.sleep(1) - transfer_events = ray.global_state.chrome_tracing_object_transfer_dump() + transfer_events = ray.object_transfer_timeline() # Make sure that each object was transferred a reasonable number of times. for x_id in object_ids: @@ -160,7 +160,7 @@ def set_weights(self, x): # Wait for profiling information to be pushed to the profile table. time.sleep(1) - transfer_events = ray.global_state.chrome_tracing_object_transfer_dump() + transfer_events = ray.object_transfer_timeline() # Make sure that each object was transferred a reasonable number of times. 
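test_object_manager.py switches to ray.object_transfer_timeline(), which returns chrome-tracing style events describing object transfers. A minimal, hedged sketch; the exact event schema is not shown in this patch, so only the call itself is assumed:

import time
import ray

# Give profiling a moment to flush transfer events, as the tests above do,
# then dump the chrome-tracing style transfer timeline.
time.sleep(1)
transfer_events = ray.object_transfer_timeline()
print("recorded", len(transfer_events), "object transfer events")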
for x_id in object_ids: diff --git a/python/ray/tests/test_stress.py b/python/ray/tests/test_stress.py index 4f94e2310b7c..1135d71011bf 100644 --- a/python/ray/tests/test_stress.py +++ b/python/ray/tests/test_stress.py @@ -393,7 +393,7 @@ def wait_for_errors(error_check): errors = [] time_left = 100 while time_left > 0: - errors = ray.error_info() + errors = ray.errors() if error_check(errors): break time_left -= 1 diff --git a/python/ray/tests/utils.py b/python/ray/tests/utils.py index 22146e89fa65..bd9291d8fa81 100644 --- a/python/ray/tests/utils.py +++ b/python/ray/tests/utils.py @@ -84,7 +84,7 @@ def run_string_as_driver_nonblocking(driver_script): def relevant_errors(error_type): - return [info for info in ray.error_info() if info["type"] == error_type] + return [info for info in ray.errors() if info["type"] == error_type] def wait_for_errors(error_type, num_errors, timeout=10): diff --git a/python/ray/tune/__init__.py b/python/ray/tune/__init__.py index 810256e07138..560a67e6b35b 100644 --- a/python/ray/tune/__init__.py +++ b/python/ray/tune/__init__.py @@ -14,5 +14,5 @@ __all__ = [ "Trainable", "TuneError", "grid_search", "register_env", "register_trainable", "run", "run_experiments", "Experiment", "function", - "sample_from", "uniform", "choice", "randint", "randn" + "sample_from", "track", "uniform", "choice", "randint", "randn" ] diff --git a/python/ray/tune/automlboard/backend/collector.py b/python/ray/tune/automlboard/backend/collector.py index 5566f479960f..dd87df1a450d 100644 --- a/python/ray/tune/automlboard/backend/collector.py +++ b/python/ray/tune/automlboard/backend/collector.py @@ -14,7 +14,7 @@ from ray.tune.automlboard.models.models import JobRecord, \ TrialRecord, ResultRecord from ray.tune.result import DEFAULT_RESULTS_DIR, JOB_META_FILE, \ - EXPR_PARARM_FILE, EXPR_RESULT_FILE, EXPR_META_FILE + EXPR_PARAM_FILE, EXPR_RESULT_FILE, EXPR_META_FILE class CollectorService(object): @@ -327,7 +327,7 @@ def _build_trial_meta(cls, expr_dir): if not meta: job_id = expr_dir.split("/")[-2] trial_id = expr_dir[-8:] - params = parse_json(os.path.join(expr_dir, EXPR_PARARM_FILE)) + params = parse_json(os.path.join(expr_dir, EXPR_PARAM_FILE)) meta = { "trial_id": trial_id, "job_id": job_id, @@ -349,7 +349,7 @@ def _build_trial_meta(cls, expr_dir): if meta.get("end_time", None): meta["end_time"] = timestamp2date(meta["end_time"]) - meta["params"] = parse_json(os.path.join(expr_dir, EXPR_PARARM_FILE)) + meta["params"] = parse_json(os.path.join(expr_dir, EXPR_PARAM_FILE)) return meta diff --git a/python/ray/tune/config_parser.py b/python/ray/tune/config_parser.py index 864ed6402639..139ef6f82bc3 100644 --- a/python/ray/tune/config_parser.py +++ b/python/ray/tune/config_parser.py @@ -76,22 +76,6 @@ def make_parser(parser_creator=None, **kwargs): default="", type=str, help="Optional URI to sync training results to (e.g. s3://bucket).") - parser.add_argument( - "--trial-name-creator", - default=None, - help="Optional creator function for the trial string, used in " - "generating a trial directory.") - parser.add_argument( - "--sync-function", - default=None, - help="Function for syncing the local_dir to upload_dir. If string, " - "then it must be a string template for syncer to run and needs to " - "include replacement fields '{local_dir}' and '{remote_dir}'.") - parser.add_argument( - "--loggers", - default=None, - help="List of logger creators to be used with each Trial. 
" - "Defaults to ray.tune.logger.DEFAULT_LOGGERS.") parser.add_argument( "--checkpoint-freq", default=0, @@ -187,7 +171,7 @@ def create_trial_from_spec(spec, output_path, parser, **trial_kwargs): A trial object with corresponding parameters to the specification. """ try: - args = parser.parse_args(to_argv(spec)) + args, _ = parser.parse_known_args(to_argv(spec)) except SystemExit: raise TuneError("Error parsing args, see above message", spec) if "resources_per_trial" in spec: diff --git a/python/ray/tune/examples/ax_example.py b/python/ray/tune/examples/ax_example.py index 07bb7f79a1f3..8620986a26ea 100644 --- a/python/ray/tune/examples/ax_example.py +++ b/python/ray/tune/examples/ax_example.py @@ -51,11 +51,13 @@ def easy_objective(config, reporter): if __name__ == "__main__": import argparse + from ax.service.ax_client import AxClient parser = argparse.ArgumentParser() parser.add_argument( "--smoke-test", action="store_true", help="Finish quickly for testing") args, _ = parser.parse_known_args() + ray.init() config = { @@ -101,13 +103,14 @@ def easy_objective(config, reporter): "bounds": [0.0, 1.0], }, ] - algo = AxSearch( + client = AxClient(enforce_sequential_optimization=False) + client.create_experiment( parameters=parameters, objective_name="hartmann6", - max_concurrent=4, minimize=True, # Optional, defaults to False. parameter_constraints=["x1 + x2 <= 2.0"], # Optional. outcome_constraints=["l2norm <= 1.25"], # Optional. ) + algo = AxSearch(client, max_concurrent=4) scheduler = AsyncHyperBandScheduler(reward_attr="hartmann6") run(easy_objective, name="ax", search_alg=algo, **config) diff --git a/python/ray/tune/examples/mnist_pytorch_trainable.py b/python/ray/tune/examples/mnist_pytorch_trainable.py index ac26d0353a98..7163dcfd6a01 100644 --- a/python/ray/tune/examples/mnist_pytorch_trainable.py +++ b/python/ray/tune/examples/mnist_pytorch_trainable.py @@ -49,6 +49,11 @@ action="store_true", default=False, help="disables CUDA training") +parser.add_argument( + "--redis-address", + default=None, + type=str, + help="The Redis address of the cluster.") parser.add_argument( "--seed", type=int, @@ -173,7 +178,7 @@ def _restore(self, checkpoint_path): from ray import tune from ray.tune.schedulers import HyperBandScheduler - ray.init() + ray.init(redis_address=args.redis_address) sched = HyperBandScheduler( time_attr="training_iteration", reward_attr="neg_mean_loss") tune.run( diff --git a/python/ray/tune/examples/track_example.py b/python/ray/tune/examples/track_example.py new file mode 100644 index 000000000000..1ccec39462d0 --- /dev/null +++ b/python/ray/tune/examples/track_example.py @@ -0,0 +1,71 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import keras +from keras.datasets import mnist +from keras.models import Sequential +from keras.layers import (Dense, Dropout, Flatten, Conv2D, MaxPooling2D) + +from ray.tune import track +from ray.tune.examples.utils import TuneKerasCallback, get_mnist_data + +parser = argparse.ArgumentParser() +parser.add_argument( + "--smoke-test", action="store_true", help="Finish quickly for testing") +parser.add_argument( + "--lr", + type=float, + default=0.01, + metavar="LR", + help="learning rate (default: 0.01)") +parser.add_argument( + "--momentum", + type=float, + default=0.5, + metavar="M", + help="SGD momentum (default: 0.5)") +parser.add_argument( + "--hidden", type=int, default=64, help="Size of hidden layer.") +args, _ = parser.parse_known_args() + + +def 
train_mnist(args): + track.init(trial_name="track-example", trial_config=vars(args)) + batch_size = 128 + num_classes = 10 + epochs = 1 if args.smoke_test else 12 + mnist.load() + x_train, y_train, x_test, y_test, input_shape = get_mnist_data() + + model = Sequential() + model.add( + Conv2D( + 32, kernel_size=(3, 3), activation="relu", + input_shape=input_shape)) + model.add(Conv2D(64, (3, 3), activation="relu")) + model.add(MaxPooling2D(pool_size=(2, 2))) + model.add(Dropout(0.5)) + model.add(Flatten()) + model.add(Dense(args.hidden, activation="relu")) + model.add(Dropout(0.5)) + model.add(Dense(num_classes, activation="softmax")) + + model.compile( + loss="categorical_crossentropy", + optimizer=keras.optimizers.SGD(lr=args.lr, momentum=args.momentum), + metrics=["accuracy"]) + + model.fit( + x_train, + y_train, + batch_size=batch_size, + epochs=epochs, + validation_data=(x_test, y_test), + callbacks=[TuneKerasCallback(track.metric)]) + track.shutdown() + + +if __name__ == "__main__": + train_mnist(args) diff --git a/python/ray/tune/examples/utils.py b/python/ray/tune/examples/utils.py index 3c73bce2bae7..a5ab1dbdb6a1 100644 --- a/python/ray/tune/examples/utils.py +++ b/python/ray/tune/examples/utils.py @@ -15,7 +15,9 @@ def __init__(self, reporter, logs={}): def on_train_end(self, epoch, logs={}): self.reporter( - timesteps_total=self.iteration, done=1, mean_accuracy=logs["acc"]) + timesteps_total=self.iteration, + done=1, + mean_accuracy=logs.get("acc")) def on_batch_end(self, batch, logs={}): self.iteration += 1 diff --git a/python/ray/tune/function_runner.py b/python/ray/tune/function_runner.py index 551f1702759f..e30e2bdf5cf0 100644 --- a/python/ray/tune/function_runner.py +++ b/python/ray/tune/function_runner.py @@ -5,9 +5,11 @@ import logging import sys import time +import inspect import threading from six.moves import queue +from ray.tune import track from ray.tune import TuneError from ray.tune.trainable import Trainable from ray.tune.result import TIME_THIS_ITER_S, RESULT_DUPLICATE @@ -244,6 +246,17 @@ def _report_thread_runner_error(self, block=False): def wrap_function(train_func): + + use_track = False + try: + func_args = inspect.getargspec(train_func).args + use_track = ("reporter" not in func_args and len(func_args) == 1) + if use_track: + logger.info("tune.track signature detected.") + except Exception: + logger.info( + "Function inspection failed - assuming reporter signature.") + class WrappedFunc(FunctionRunner): def _trainable_func(self, config, reporter): output = train_func(config, reporter) @@ -253,4 +266,12 @@ def _trainable_func(self, config, reporter): reporter(**{RESULT_DUPLICATE: True}) return output - return WrappedFunc + class WrappedTrackFunc(FunctionRunner): + def _trainable_func(self, config, reporter): + track.init(_tune_reporter=reporter) + output = train_func(config) + reporter(**{RESULT_DUPLICATE: True}) + track.shutdown() + return output + + return WrappedTrackFunc if use_track else WrappedFunc diff --git a/python/ray/tune/logger.py b/python/ray/tune/logger.py index 1f6ddd3d30f9..9b5a4830ad38 100644 --- a/python/ray/tune/logger.py +++ b/python/ray/tune/logger.py @@ -51,6 +51,11 @@ def on_result(self, result): raise NotImplementedError + def update_config(self, config): + """Updates the config for all loggers.""" + + pass + def close(self): """Releases all resources used by this logger.""" @@ -69,17 +74,7 @@ def on_result(self, result): class JsonLogger(Logger): def _init(self): - config_out = os.path.join(self.logdir, "params.json") - with 
open(config_out, "w") as f: - json.dump( - self.config, - f, - indent=2, - sort_keys=True, - cls=_SafeFallbackEncoder) - config_pkl = os.path.join(self.logdir, "params.pkl") - with open(config_pkl, "wb") as f: - cloudpickle.dump(self.config, f) + self.update_config(self.config) local_file = os.path.join(self.logdir, "result.json") self.local_out = open(local_file, "a") @@ -97,6 +92,15 @@ def flush(self): def close(self): self.local_out.close() + def update_config(self, config): + self.config = config + config_out = os.path.join(self.logdir, "params.json") + with open(config_out, "w") as f: + json.dump(self.config, f, cls=_SafeFallbackEncoder) + config_pkl = os.path.join(self.logdir, "params.pkl") + with open(config_pkl, "wb") as f: + cloudpickle.dump(self.config, f) + def to_tf_values(result, path): values = [] @@ -119,10 +123,14 @@ class TFLogger(Logger): def _init(self): try: global tf, use_tf150_api - import tensorflow - tf = tensorflow - use_tf150_api = (distutils.version.LooseVersion(tf.VERSION) >= - distutils.version.LooseVersion("1.5.0")) + if "RLLIB_TEST_NO_TF_IMPORT" in os.environ: + logger.warning("Not importing TensorFlow for test purposes") + tf = None + else: + import tensorflow + tf = tensorflow + use_tf150_api = (distutils.version.LooseVersion(tf.VERSION) >= + distutils.version.LooseVersion("1.5.0")) except ImportError: logger.warning("Couldn't import TensorFlow - " "disabling TensorBoard logging.") @@ -228,6 +236,10 @@ def on_result(self, result): self._log_syncer.set_worker_ip(result.get(NODE_IP)) self._log_syncer.sync_if_needed() + def update_config(self, config): + for _logger in self._loggers: + _logger.update_config(config) + def close(self): for _logger in self._loggers: _logger.close() diff --git a/python/ray/tune/ray_trial_executor.py b/python/ray/tune/ray_trial_executor.py index b27eeb1f1ddb..bfd9e0ad0d29 100644 --- a/python/ray/tune/ray_trial_executor.py +++ b/python/ray/tune/ray_trial_executor.py @@ -356,7 +356,7 @@ def _return_resources(self, resources): def _update_avail_resources(self, num_retries=5): for i in range(num_retries): try: - resources = ray.global_state.cluster_resources() + resources = ray.cluster_resources() except Exception: # TODO(rliaw): Remove this when local mode is fixed. # https://github.com/ray-project/ray/issues/4147 diff --git a/python/ray/tune/result.py b/python/ray/tune/result.py index 2978fe540d18..51a67d5931a7 100644 --- a/python/ray/tune/result.py +++ b/python/ray/tune/result.py @@ -68,7 +68,7 @@ EXPR_META_FILE = "trial_status.json" # File that stores parameters of the trial. -EXPR_PARARM_FILE = "params.json" +EXPR_PARAM_FILE = "params.json" # File that stores the progress of the trial. EXPR_PROGRESS_FILE = "progress.csv" diff --git a/python/ray/tune/suggest/ax.py b/python/ray/tune/suggest/ax.py index a48852e84864..75b982d67087 100644 --- a/python/ray/tune/suggest/ax.py +++ b/python/ray/tune/suggest/ax.py @@ -6,16 +6,19 @@ import ax except ImportError: ax = None +import logging from ray.tune.suggest.suggestion import SuggestionAlgorithm +logger = logging.getLogger(__name__) + class AxSearch(SuggestionAlgorithm): """A wrapper around Ax to provide trial suggestions. - Requires Ax to be installed. - Ax is an open source tool from Facebook for configuring and - optimizing experiments. More information can be found in https://ax.dev/. + Requires Ax to be installed. Ax is an open source tool from + Facebook for configuring and optimizing experiments. More information + can be found in https://ax.dev/. 
Parameters: parameters (list[dict]): Parameters in the experiment search space. @@ -48,40 +51,27 @@ class AxSearch(SuggestionAlgorithm): >>> objective_name="hartmann6", max_concurrent=4) """ - def __init__(self, - parameters, - objective_name, - max_concurrent=10, - minimize=False, - parameter_constraints=None, - outcome_constraints=None, - **kwargs): + def __init__(self, ax_client, max_concurrent=10, **kwargs): assert ax is not None, "Ax must be installed!" - from ax.service import ax_client assert type(max_concurrent) is int and max_concurrent > 0 - self._ax = ax_client.AxClient(enforce_sequential_optimization=False) - self._ax.create_experiment( - name="ax", - parameters=parameters, - objective_name=objective_name, - minimize=minimize, - parameter_constraints=parameter_constraints or [], - outcome_constraints=outcome_constraints or [], - ) + self._ax = ax_client + exp = self._ax.experiment + self._objective_name = exp.optimization_config.objective.metric.name + if self._ax._enforce_sequential_optimization: + logger.warning("Detected sequential enforcement. Setting max " + "concurrency to 1.") + max_concurrent = 1 self._max_concurrent = max_concurrent - self._parameters = [d["name"] for d in parameters] - self._objective_name = objective_name + self._parameters = list(exp.parameters) self._live_index_mapping = {} - super(AxSearch, self).__init__(**kwargs) def _suggest(self, trial_id): if self._num_live_trials() >= self._max_concurrent: return None parameters, trial_index = self._ax.get_next_trial() - suggested_config = list(parameters.values()) self._live_index_mapping[trial_id] = trial_index - return dict(zip(self._parameters, suggested_config)) + return parameters def on_trial_result(self, trial_id, result): pass diff --git a/python/ray/tune/tests/test_cluster.py b/python/ray/tune/tests/test_cluster.py index 4f962299d51a..e00e5da371c5 100644 --- a/python/ray/tune/tests/test_cluster.py +++ b/python/ray/tune/tests/test_cluster.py @@ -71,7 +71,7 @@ def test_counting_resources(start_connected_cluster): """Tests that Tune accounting is consistent with actual cluster.""" cluster = start_connected_cluster nodes = [] - assert ray.global_state.cluster_resources()["CPU"] == 1 + assert ray.cluster_resources()["CPU"] == 1 runner = TrialRunner(BasicVariantGenerator()) kwargs = {"stopping_criterion": {"training_iteration": 10}} @@ -82,17 +82,17 @@ def test_counting_resources(start_connected_cluster): runner.step() # run 1 nodes += [cluster.add_node(num_cpus=1)] cluster.wait_for_nodes() - assert ray.global_state.cluster_resources()["CPU"] == 2 + assert ray.cluster_resources()["CPU"] == 2 cluster.remove_node(nodes.pop()) cluster.wait_for_nodes() - assert ray.global_state.cluster_resources()["CPU"] == 1 + assert ray.cluster_resources()["CPU"] == 1 runner.step() # run 2 assert sum(t.status == Trial.RUNNING for t in runner.get_trials()) == 1 for i in range(5): nodes += [cluster.add_node(num_cpus=1)] cluster.wait_for_nodes() - assert ray.global_state.cluster_resources()["CPU"] == 6 + assert ray.cluster_resources()["CPU"] == 6 runner.step() # 1 result assert sum(t.status == Trial.RUNNING for t in runner.get_trials()) == 2 @@ -120,7 +120,7 @@ def test_remove_node_before_result(start_connected_emptyhead_cluster): cluster.remove_node(node) cluster.add_node(num_cpus=1) cluster.wait_for_nodes() - assert ray.global_state.cluster_resources()["CPU"] == 1 + assert ray.cluster_resources()["CPU"] == 1 for i in range(3): runner.step() diff --git a/python/ray/tune/tests/test_commands.py 
b/python/ray/tune/tests/test_commands.py index ab27cc65d56c..f55dc83362c3 100644 --- a/python/ray/tune/tests/test_commands.py +++ b/python/ray/tune/tests/test_commands.py @@ -33,7 +33,7 @@ def __exit__(self, *args): @pytest.fixture def start_ray(): - ray.init() + ray.init(log_to_driver=False) _register_all() yield ray.shutdown() diff --git a/python/ray/tune/tests/test_track.py b/python/ray/tune/tests/test_track.py new file mode 100644 index 000000000000..d3b6c38d745a --- /dev/null +++ b/python/ray/tune/tests/test_track.py @@ -0,0 +1,84 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import pandas as pd +import unittest + +import ray +from ray import tune +from ray.tune import track +from ray.tune.result import EXPR_PARAM_FILE, EXPR_RESULT_FILE + + +def _check_json_val(fname, key, val): + with open(fname, "r") as f: + df = pd.read_json(f, typ="frame", lines=True) + return key in df.columns and (df[key].tail(n=1) == val).all() + + +class TrackApiTest(unittest.TestCase): + def tearDown(self): + track.shutdown() + ray.shutdown() + + def testSessionInitShutdown(self): + self.assertTrue(track._session is None) + + # Checks that the singleton _session is created/destroyed + # by track.init() and track.shutdown() + for _ in range(2): + # do it twice to see that we can reopen the session + track.init(trial_name="test_init") + self.assertTrue(track._session is not None) + track.shutdown() + self.assertTrue(track._session is None) + + def testLogCreation(self): + """Checks that track.init() starts logger and creates log files.""" + track.init(trial_name="test_init") + session = track.get_session() + self.assertTrue(session is not None) + + self.assertTrue(os.path.isdir(session.logdir)) + + params_path = os.path.join(session.logdir, EXPR_PARAM_FILE) + result_path = os.path.join(session.logdir, EXPR_RESULT_FILE) + + self.assertTrue(os.path.exists(params_path)) + self.assertTrue(os.path.exists(result_path)) + self.assertTrue(session.logdir == track.trial_dir()) + + def testMetric(self): + track.init(trial_name="test_log") + session = track.get_session() + for i in range(5): + track.log(test=i) + result_path = os.path.join(session.logdir, EXPR_RESULT_FILE) + self.assertTrue(_check_json_val(result_path, "test", i)) + + def testRayOutput(self): + """Checks that local and remote log format are the same.""" + ray.init() + + def testme(config): + for i in range(config["iters"]): + track.log(iteration=i, hi="test") + + trials = tune.run(testme, config={"iters": 5}) + trial_res = trials[0].last_result + self.assertTrue(trial_res["hi"], "test") + self.assertTrue(trial_res["training_iteration"], 5) + + def testLocalMetrics(self): + """Checks that metric state is updated correctly.""" + track.init(trial_name="test_logs") + session = track.get_session() + self.assertEqual(set(session.trial_config.keys()), {"trial_id"}) + + result_path = os.path.join(session.logdir, EXPR_RESULT_FILE) + track.log(test=1) + self.assertTrue(_check_json_val(result_path, "test", 1)) + track.log(iteration=1, test=2) + self.assertTrue(_check_json_val(result_path, "test", 2)) diff --git a/python/ray/tune/tests/test_trial_runner.py b/python/ray/tune/tests/test_trial_runner.py index 19930559ce7d..a9bf8e3239c6 100644 --- a/python/ray/tune/tests/test_trial_runner.py +++ b/python/ray/tune/tests/test_trial_runner.py @@ -1532,7 +1532,7 @@ def testFailureRecoveryNodeRemoval(self): runner.add_trial(Trial("__fake", **kwargs)) trials = runner.get_trials() - with 
patch("ray.global_state.cluster_resources") as resource_mock: + with patch("ray.cluster_resources") as resource_mock: resource_mock.return_value = {"CPU": 1, "GPU": 1} runner.step() self.assertEqual(trials[0].status, Trial.RUNNING) diff --git a/python/ray/tune/track/__init__.py b/python/ray/tune/track/__init__.py new file mode 100644 index 000000000000..a35511e89350 --- /dev/null +++ b/python/ray/tune/track/__init__.py @@ -0,0 +1,71 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import logging + +from ray.tune.track.session import TrackSession + +logger = logging.getLogger(__name__) + +_session = None + + +def get_session(): + global _session + if not _session: + raise ValueError("Session not detected. Try `track.init()`?") + return _session + + +def init(ignore_reinit_error=True, **session_kwargs): + """Initializes the global trial context for this process. + + This creates a TrackSession object and the corresponding hooks for logging. + + Examples: + >>> from ray.tune import track + >>> track.init() + """ + global _session + + if _session: + # TODO(ng): would be nice to stack crawl at creation time to report + # where that initial trial was created, and that creation line + # info is helpful to keep around anyway. + reinit_msg = "A session already exists in the current context." + if ignore_reinit_error: + if not _session.is_tune_session: + logger.warning(reinit_msg) + return + else: + raise ValueError(reinit_msg) + + _session = TrackSession(**session_kwargs) + + +def shutdown(): + """Cleans up the trial and removes it from the global context.""" + + global _session + if _session: + _session.close() + _session = None + + +def log(**kwargs): + """Applies TrackSession.log to the trial in the current context.""" + _session = get_session() + return _session.log(**kwargs) + + +def trial_dir(): + """Returns the directory where trial results are saved. + + This includes json data containing the session's parameters and metrics. + """ + _session = get_session() + return _session.logdir + + +__all__ = ["TrackSession", "session", "log", "trial_dir", "init", "shutdown"] diff --git a/python/ray/tune/track/session.py b/python/ray/tune/track/session.py new file mode 100644 index 000000000000..faf850e5fea2 --- /dev/null +++ b/python/ray/tune/track/session.py @@ -0,0 +1,110 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +from datetime import datetime + +from ray.tune.trial import Trial +from ray.tune.result import DEFAULT_RESULTS_DIR, TRAINING_ITERATION +from ray.tune.logger import UnifiedLogger, Logger + + +class _ReporterHook(Logger): + def __init__(self, tune_reporter): + self.tune_reporter = tune_reporter + + def on_result(self, metrics): + return self.tune_reporter(**metrics) + + +class TrackSession(object): + """Manages results for a single session. + + Represents a single Trial in an experiment. + + Attributes: + trial_name (str): Custom trial name. + experiment_dir (str): Directory where results for all trials + are stored. Each session is stored into a unique directory + inside experiment_dir. + upload_dir (str): Directory to sync results to. + trial_config (dict): Parameters that will be logged to disk. + _tune_reporter (StatusReporter): For rerouting when using Tune. + Will not instantiate logging if not None. 
+ """ + + def __init__(self, + trial_name="", + experiment_dir=None, + upload_dir=None, + trial_config=None, + _tune_reporter=None): + self._experiment_dir = None + self._logdir = None + self._upload_dir = None + self.trial_config = None + self._iteration = -1 + self.is_tune_session = bool(_tune_reporter) + self.trial_id = Trial.generate_id() + if trial_name: + self.trial_id = trial_name + "_" + self.trial_id + if self.is_tune_session: + self._logger = _ReporterHook(_tune_reporter) + else: + self._initialize_logging(trial_name, experiment_dir, upload_dir, + trial_config) + + def _initialize_logging(self, + trial_name="", + experiment_dir=None, + upload_dir=None, + trial_config=None): + + # TODO(rliaw): In other parts of the code, this is `local_dir`. + if experiment_dir is None: + experiment_dir = os.path.join(DEFAULT_RESULTS_DIR, "default") + + self._experiment_dir = os.path.expanduser(experiment_dir) + + # TODO(rliaw): Refactor `logdir` to `trial_dir`. + self._logdir = Trial.create_logdir(trial_name, self._experiment_dir) + self._upload_dir = upload_dir + self.trial_config = trial_config or {} + + # misc metadata to save as well + self.trial_config["trial_id"] = self.trial_id + self._logger = UnifiedLogger(self.trial_config, self._logdir, + self._upload_dir) + + def log(self, **metrics): + """Logs all named arguments specified in **metrics. + + This will log trial metrics locally, and they will be synchronized + with the driver periodically through ray. + + Arguments: + metrics: named arguments with corresponding values to log. + """ + + # TODO: Implement a batching mechanism for multiple calls to `log` + # within the same iteration. + self._iteration += 1 + metrics_dict = metrics.copy() + metrics_dict.update({"trial_id": self.trial_id}) + + # TODO: Move Trainable autopopulation to a util function + metrics_dict.setdefault(TRAINING_ITERATION, self._iteration) + self._logger.on_result(metrics_dict) + + def close(self): + self.trial_config["trial_completed"] = True + self.trial_config["end_time"] = datetime.now().isoformat() + # TODO(rliaw): Have Tune support updated configs + self._logger.update_config(self.trial_config) + self._logger.close() + + @property + def logdir(self): + """Trial logdir (subdir of given experiment directory)""" + return self._logdir diff --git a/python/ray/tune/trial.py b/python/ray/tune/trial.py index ad61e8d4b393..91ea941b8cf0 100644 --- a/python/ray/tune/trial.py +++ b/python/ray/tune/trial.py @@ -8,6 +8,7 @@ from datetime import datetime import logging import json +import uuid import time import tempfile import os @@ -27,7 +28,7 @@ from ray.tune.result import (DEFAULT_RESULTS_DIR, DONE, HOSTNAME, PID, TIME_TOTAL_S, TRAINING_ITERATION, TIMESTEPS_TOTAL, EPISODE_REWARD_MEAN, MEAN_LOSS, MEAN_ACCURACY) -from ray.utils import _random_string, binary_to_hex, hex_to_binary +from ray.utils import binary_to_hex, hex_to_binary DEBUG_PRINT_INTERVAL = 5 MAX_LEN_IDENTIFIER = 130 @@ -272,7 +273,7 @@ def __init__(self, # Trial config self.trainable_name = trainable_name self.config = config or {} - self.local_dir = os.path.expanduser(local_dir) + self.local_dir = local_dir # This remains unexpanded for syncing. 
self.experiment_tag = experiment_tag self.resources = ( resources @@ -341,19 +342,23 @@ def _registration_check(cls, trainable_name): @classmethod def generate_id(cls): - return binary_to_hex(_random_string())[:8] + return str(uuid.uuid1().hex)[:8] + + @classmethod + def create_logdir(cls, identifier, local_dir): + local_dir = os.path.expanduser(local_dir) + if not os.path.exists(local_dir): + os.makedirs(local_dir) + return tempfile.mkdtemp( + prefix="{}_{}".format(identifier[:MAX_LEN_IDENTIFIER], date_str()), + dir=local_dir) def init_logger(self): """Init logger.""" if not self.result_logger: - if not os.path.exists(self.local_dir): - os.makedirs(self.local_dir) if not self.logdir: - self.logdir = tempfile.mkdtemp( - prefix="{}_{}".format( - str(self)[:MAX_LEN_IDENTIFIER], date_str()), - dir=self.local_dir) + self.logdir = Trial.create_logdir(str(self), self.local_dir) elif not os.path.exists(self.logdir): os.makedirs(self.logdir) diff --git a/python/ray/utils.py b/python/ray/utils.py index 0f26aa22d03a..7b87486e325e 100644 --- a/python/ray/utils.py +++ b/python/ray/utils.py @@ -216,6 +216,10 @@ def binary_to_object_id(binary_object_id): return ray.ObjectID(binary_object_id) +def binary_to_task_id(binary_task_id): + return ray.TaskID(binary_task_id) + + def binary_to_hex(identifier): hex_identifier = binascii.hexlify(identifier) if sys.version_info >= (3, 0): diff --git a/python/ray/worker.py b/python/ray/worker.py index bbcf1bb2235e..c886159aafec 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -25,7 +25,6 @@ import ray.cloudpickle as pickle import ray.experimental.signal as ray_signal import ray.experimental.no_return -import ray.experimental.state as state import ray.gcs_utils import ray.memory_monitor as memory_monitor import ray.node @@ -35,6 +34,7 @@ import ray.serialization as serialization import ray.services as services import ray.signature +import ray.state from ray import ( ActorHandleID, @@ -198,7 +198,7 @@ def task_context(self): # to the current task ID may not be correct. Generate a # random task ID so that the backend can differentiate # between different threads. - self._task_context.current_task_id = TaskID(_random_string()) + self._task_context.current_task_id = TaskID.from_random() if getattr(self, "_multithreading_warned", False) is not True: logger.warning( "Calling ray.get or ray.wait in a separate thread " @@ -1108,8 +1108,6 @@ def get_webui_url(): per worker process. """ -global_state = state.GlobalState() - _global_node = None """ray.node.Node: The global node object that is created by ray.init().""" @@ -1134,14 +1132,6 @@ def print_failed_task(task_status): task_status["error_message"])) -def error_info(): - """Return information about failed tasks.""" - worker = global_worker - worker.check_connected() - return (global_state.error_messages(driver_id=worker.task_driver_id) + - global_state.error_messages(driver_id=DriverID.nil())) - - def _initialize_serialization(driver_id, worker=global_worker): """Initialize the serialization library. @@ -1488,7 +1478,7 @@ def shutdown(exiting_interpreter=False): disconnect() # Disconnect global state from GCS. - global_state.disconnect() + ray.state.state.disconnect() # Shut down the Ray processes. global _global_node @@ -1644,7 +1634,7 @@ def listen_error_messages_raylet(worker, task_error_queue, threads_stopped): try: # Get the exports that occurred before the call to subscribe. 
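Reviewer note: a small sketch of the two classmethods factored out of `Trial.init_logger` in the hunk above, so the new `track` code can reuse them; the example path in the comment is illustrative only.

    from ray.tune.trial import Trial

    trial_id = Trial.generate_id()   # now an 8-character uuid1 hex prefix
    logdir = Trial.create_logdir("my_exp", "~/ray_results/default")
    # create_logdir expands the user directory, creates it if missing, and
    # returns a fresh mkdtemp directory whose name starts with the (truncated)
    # identifier and a date string, e.g.
    # /home/user/ray_results/default/my_exp_<date>_<random suffix>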
- error_messages = global_state.error_messages(worker.task_driver_id) + error_messages = ray.errors(include_cluster_errors=False) for error_message in error_messages: logger.error(error_message) @@ -1725,7 +1715,7 @@ def connect(node, else: # This is the code path of driver mode. if driver_id is None: - driver_id = DriverID(_random_string()) + driver_id = DriverID.from_random() if not isinstance(driver_id, DriverID): raise TypeError("The type of given driver id must be DriverID.") @@ -1774,7 +1764,7 @@ def connect(node, worker.lock = threading.RLock() # Create an object for interfacing with the global state. - global_state._initialize_global_state( + ray.state.state._initialize_global_state( node.redis_address, redis_password=node.redis_password) # Register the worker with Redis. @@ -1834,6 +1824,7 @@ def connect(node, # Create an object store client. worker.plasma_client = thread_safe_client( plasma.connect(node.plasma_store_socket_name, None, 0, 300)) + driver_id_str = _random_string() # If this is a driver, set the current task ID, the task driver ID, and set # the task index to 0. @@ -1865,7 +1856,7 @@ def connect(node, function_descriptor.get_function_descriptor_list(), [], # arguments. 0, # num_returns. - TaskID(_random_string()), # parent_task_id. + TaskID(driver_id_str[:TaskID.size()]), # parent_task_id. 0, # parent_counter. ActorID.nil(), # actor_creation_id. ObjectID.nil(), # actor_creation_dummy_object_id. @@ -1880,11 +1871,12 @@ def connect(node, ) # Add the driver task to the task table. - global_state._execute_command(driver_task.task_id(), "RAY.TABLE_ADD", - ray.gcs_utils.TablePrefix.RAYLET_TASK, - ray.gcs_utils.TablePubsub.RAYLET_TASK, - driver_task.task_id().binary(), - driver_task._serialized_raylet_task()) + ray.state.state._execute_command(driver_task.task_id(), + "RAY.TABLE_ADD", + ray.gcs_utils.TablePrefix.RAYLET_TASK, + ray.gcs_utils.TablePubsub.RAYLET_TASK, + driver_task.task_id().binary(), + driver_task._serialized_raylet_task()) # Set the driver's current task ID to the task ID assigned to the # driver task. @@ -1894,7 +1886,7 @@ def connect(node, node.raylet_socket_name, ClientID(worker.worker_id), (mode == WORKER_MODE), - DriverID(worker.current_task_id.binary()), + DriverID(driver_id_str), ) # Start the import thread diff --git a/site/_config.yml b/site/_config.yml index 24ec957d22e7..676d4f0c0bb4 100644 --- a/site/_config.yml +++ b/site/_config.yml @@ -13,10 +13,10 @@ # you will see them accessed via {{ site.title }}, {{ site.email }}, and so on. # You can create any custom variable you would like, and they will be accessible # in the templates via {{ site.myvariable }}. -title: "Ray: A Distributed Execution Framework for AI Applications" +title: "Ray: A fast and simple framework for distributed applications" email: "" description: > # this means to ignore newlines until "baseurl:" - Ray is a flexible, high-performance distributed execution framework for AI applications. + Ray is a fast and simple framework for building and running distributed applications. baseurl: "" # the subpath of your site, e.g. /blog url: "" # the base hostname & protocol for your site, e.g. http://example.com github_username: ray-project diff --git a/src/ray/constants.h b/src/ray/constants.h index 2035938be267..c92e6a74aa5d 100644 --- a/src/ray/constants.h +++ b/src/ray/constants.h @@ -4,7 +4,7 @@ #include #include -/// Length of Ray IDs in bytes. +/// Length of Ray full-length IDs in bytes. 
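Reviewer note: the worker.py changes above drop the module-level `global_state` object and the `error_info()` helper in favor of top-level functions. A hedged before/after sketch; the printed values are illustrative, and the old names are listed only as the entry points being removed.

    import ray

    ray.init()

    # Before: ray.global_state.cluster_resources() and ray.error_info()
    # After: the top-level helpers referenced by the updated code and tests.
    print(ray.cluster_resources())                    # e.g. {"CPU": 8.0, "GPU": 1.0}
    print(ray.errors(include_cluster_errors=False))   # driver-scoped error messages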
constexpr int64_t kUniqueIDSize = 20; /// An ObjectID's bytes are split into the task ID itself and the index of the @@ -13,6 +13,9 @@ constexpr int kObjectIdIndexSize = 32; static_assert(kObjectIdIndexSize % CHAR_BIT == 0, "ObjectID prefix not a multiple of bytes"); +/// Length of Ray TaskID in bytes. 32-bit integer is used for object index. +constexpr int64_t kTaskIDSize = kUniqueIDSize - kObjectIdIndexSize / 8; + /// The maximum number of objects that can be returned by a task when finishing /// execution. An ObjectID's bytes are split into the task ID itself and the /// index of the object's creation. A positive index indicates an object diff --git a/src/ray/gcs/client_test.cc b/src/ray/gcs/client_test.cc index d2d225c0a687..7f69c482e5eb 100644 --- a/src/ray/gcs/client_test.cc +++ b/src/ray/gcs/client_test.cc @@ -89,7 +89,7 @@ void TestTableLookup(const DriverID &driver_id, data->task_specification = "123"; // Check that we added the correct task. - auto add_callback = [task_id, data](gcs::AsyncGcsClient *client, const UniqueID &id, + auto add_callback = [task_id, data](gcs::AsyncGcsClient *client, const TaskID &id, const protocol::TaskT &d) { ASSERT_EQ(id, task_id); ASSERT_EQ(data->task_specification, d.task_specification); @@ -104,7 +104,7 @@ void TestTableLookup(const DriverID &driver_id, }; // Check that the lookup does not return an empty entry. - auto failure_callback = [](gcs::AsyncGcsClient *client, const UniqueID &id) { + auto failure_callback = [](gcs::AsyncGcsClient *client, const TaskID &id) { RAY_CHECK(false); }; @@ -139,7 +139,7 @@ void TestLogLookup(const DriverID &driver_id, auto data = std::make_shared(); data->node_manager_id = node_manager_id; // Check that we added the correct object entries. - auto add_callback = [task_id, data](gcs::AsyncGcsClient *client, const UniqueID &id, + auto add_callback = [task_id, data](gcs::AsyncGcsClient *client, const TaskID &id, const TaskReconstructionDataT &d) { ASSERT_EQ(id, task_id); ASSERT_EQ(data->node_manager_id, d.node_manager_id); @@ -150,7 +150,7 @@ void TestLogLookup(const DriverID &driver_id, // Check that lookup returns the added object entries. auto lookup_callback = [task_id, node_manager_ids]( - gcs::AsyncGcsClient *client, const UniqueID &id, + gcs::AsyncGcsClient *client, const TaskID &id, const std::vector &data) { ASSERT_EQ(id, task_id); for (const auto &entry : data) { @@ -181,11 +181,11 @@ void TestTableLookupFailure(const DriverID &driver_id, TaskID task_id = TaskID::from_random(); // Check that the lookup does not return data. - auto lookup_callback = [](gcs::AsyncGcsClient *client, const UniqueID &id, + auto lookup_callback = [](gcs::AsyncGcsClient *client, const TaskID &id, const protocol::TaskT &d) { RAY_CHECK(false); }; // Check that the lookup returns an empty entry. - auto failure_callback = [task_id](gcs::AsyncGcsClient *client, const UniqueID &id) { + auto failure_callback = [task_id](gcs::AsyncGcsClient *client, const TaskID &id) { ASSERT_EQ(id, task_id); test->Stop(); }; @@ -215,7 +215,7 @@ void TestLogAppendAt(const DriverID &driver_id, } // Check that we added the correct task. 
- auto failure_callback = [task_id](gcs::AsyncGcsClient *client, const UniqueID &id, + auto failure_callback = [task_id](gcs::AsyncGcsClient *client, const TaskID &id, const TaskReconstructionDataT &d) { ASSERT_EQ(id, task_id); test->IncrementNumCallbacks(); @@ -241,7 +241,7 @@ void TestLogAppendAt(const DriverID &driver_id, /*done callback=*/nullptr, failure_callback, /*log_length=*/1)); auto lookup_callback = [node_manager_ids]( - gcs::AsyncGcsClient *client, const UniqueID &id, + gcs::AsyncGcsClient *client, const TaskID &id, const std::vector &data) { std::vector appended_managers; for (const auto &entry : data) { @@ -271,7 +271,7 @@ void TestSet(const DriverID &driver_id, std::shared_ptr cli auto data = std::make_shared(); data->manager = manager; // Check that we added the correct object entries. - auto add_callback = [object_id, data](gcs::AsyncGcsClient *client, const UniqueID &id, + auto add_callback = [object_id, data](gcs::AsyncGcsClient *client, const ObjectID &id, const ObjectTableDataT &d) { ASSERT_EQ(id, object_id); ASSERT_EQ(data->manager, d.manager); @@ -297,7 +297,7 @@ void TestSet(const DriverID &driver_id, std::shared_ptr cli data->manager = manager; // Check that we added the correct object entries. auto remove_entry_callback = [object_id, data]( - gcs::AsyncGcsClient *client, const UniqueID &id, const ObjectTableDataT &d) { + gcs::AsyncGcsClient *client, const ObjectID &id, const ObjectTableDataT &d) { ASSERT_EQ(id, object_id); ASSERT_EQ(data->manager, d.manager); test->IncrementNumCallbacks(); @@ -338,7 +338,7 @@ void TestDeleteKeysFromLog( task_id = TaskID::from_random(); ids.push_back(task_id); // Check that we added the correct object entries. - auto add_callback = [task_id, data](gcs::AsyncGcsClient *client, const UniqueID &id, + auto add_callback = [task_id, data](gcs::AsyncGcsClient *client, const TaskID &id, const TaskReconstructionDataT &d) { ASSERT_EQ(id, task_id); ASSERT_EQ(data->node_manager_id, d.node_manager_id); @@ -350,7 +350,7 @@ void TestDeleteKeysFromLog( for (const auto &task_id : ids) { // Check that lookup returns the added object entries. auto lookup_callback = [task_id, data_vector]( - gcs::AsyncGcsClient *client, const UniqueID &id, + gcs::AsyncGcsClient *client, const TaskID &id, const std::vector &data) { ASSERT_EQ(id, task_id); ASSERT_EQ(data.size(), 1); @@ -386,7 +386,7 @@ void TestDeleteKeysFromTable(const DriverID &driver_id, task_id = TaskID::from_random(); ids.push_back(task_id); // Check that we added the correct object entries. - auto add_callback = [task_id, data](gcs::AsyncGcsClient *client, const UniqueID &id, + auto add_callback = [task_id, data](gcs::AsyncGcsClient *client, const TaskID &id, const protocol::TaskT &d) { ASSERT_EQ(id, task_id); ASSERT_EQ(data->task_specification, d.task_specification); @@ -434,7 +434,7 @@ void TestDeleteKeysFromSet(const DriverID &driver_id, object_id = ObjectID::from_random(); ids.push_back(object_id); // Check that we added the correct object entries. - auto add_callback = [object_id, data](gcs::AsyncGcsClient *client, const UniqueID &id, + auto add_callback = [object_id, data](gcs::AsyncGcsClient *client, const ObjectID &id, const ObjectTableDataT &d) { ASSERT_EQ(id, object_id); ASSERT_EQ(data->manager, d.manager); @@ -607,7 +607,7 @@ void TestLogSubscribeAll(const DriverID &driver_id, } // Callback for a notification. 
auto notification_callback = [driver_ids](gcs::AsyncGcsClient *client, - const UniqueID &id, + const DriverID &id, const std::vector data) { ASSERT_EQ(id, driver_ids[test->NumCallbacks()]); // Check that we get notifications in the same order as the writes. @@ -657,7 +657,7 @@ void TestSetSubscribeAll(const DriverID &driver_id, // Callback for a notification. auto notification_callback = [object_ids, managers]( - gcs::AsyncGcsClient *client, const UniqueID &id, + gcs::AsyncGcsClient *client, const ObjectID &id, const GcsTableNotificationMode notification_mode, const std::vector data) { if (test->NumCallbacks() < 3 * 3) { @@ -752,7 +752,7 @@ void TestTableSubscribeId(const DriverID &driver_id, // The failure callback should be called once since both keys start as empty. bool failure_notification_received = false; auto failure_callback = [task_id2, &failure_notification_received]( - gcs::AsyncGcsClient *client, const UniqueID &id) { + gcs::AsyncGcsClient *client, const TaskID &id) { ASSERT_EQ(id, task_id2); // The failure notification should be the first notification received. ASSERT_EQ(test->NumCallbacks(), 0); @@ -962,7 +962,7 @@ void TestTableSubscribeCancel(const DriverID &driver_id, // The failure callback should not be called since all keys are non-empty // when notifications are requested. - auto failure_callback = [](gcs::AsyncGcsClient *client, const UniqueID &id) { + auto failure_callback = [](gcs::AsyncGcsClient *client, const TaskID &id) { RAY_CHECK(false); }; @@ -1188,12 +1188,12 @@ void ClientTableNotification(gcs::AsyncGcsClient *client, const ClientID &client ASSERT_EQ(client_id, added_id); ASSERT_EQ(ClientID::from_binary(data.client_id), added_id); ASSERT_EQ(ClientID::from_binary(data.client_id), added_id); - ASSERT_EQ(data.is_insertion, is_insertion); + ASSERT_EQ(data.entry_type == EntryType::INSERTION, is_insertion); ClientTableDataT cached_client; client->client_table().GetClient(added_id, cached_client); ASSERT_EQ(ClientID::from_binary(cached_client.client_id), added_id); - ASSERT_EQ(cached_client.is_insertion, is_insertion); + ASSERT_EQ(cached_client.entry_type == EntryType::INSERTION, is_insertion); } void TestClientTableConnect(const DriverID &driver_id, diff --git a/src/ray/gcs/format/gcs.fbs b/src/ray/gcs/format/gcs.fbs index 7acb24d27bd6..b81f388d88c5 100644 --- a/src/ray/gcs/format/gcs.fbs +++ b/src/ray/gcs/format/gcs.fbs @@ -39,6 +39,14 @@ enum TablePubsub:int { DRIVER, } +// Enum for the entry type in the ClientTable +enum EntryType:int { + INSERTION = 0, + DELETION, + RES_CREATEUPDATE, + RES_DELETE, +} + table Arg { // Object ID for pass-by-reference arguments. Normally there is only one // object ID in this list which represents the object that is being passed. @@ -81,9 +89,8 @@ table TaskInfo { new_actor_handles: string; // Task arguments. args: [Arg]; - // Object IDs of return values. This is a long string that concatenate - // all of the return object IDs of this task. - returns: string; + // Number of return objects. + num_returns: int; // The required_resources vector indicates the quantities of the different // resources required by this task. required_resources: [ResourcePair]; @@ -267,9 +274,8 @@ table ClientTableData { // The port at which the client's object manager is listening for TCP // connections from other object managers. object_manager_port: int; - // True if the message is about the addition of a client and false if it is - // about the deletion of a client. 
- is_insertion: bool; + // Enum to store the entry type in the log + entry_type: EntryType = INSERTION; resources_total_label: [string]; resources_total_capacity: [double]; } diff --git a/src/ray/gcs/redis_context.cc b/src/ray/gcs/redis_context.cc index fe61df288d6b..6b03fa735007 100644 --- a/src/ray/gcs/redis_context.cc +++ b/src/ray/gcs/redis_context.cc @@ -226,45 +226,6 @@ Status RedisContext::AttachToEventLoop(aeEventLoop *loop) { } } -Status RedisContext::RunAsync(const std::string &command, const UniqueID &id, - const uint8_t *data, int64_t length, - const TablePrefix prefix, const TablePubsub pubsub_channel, - RedisCallback redisCallback, int log_length) { - int64_t callback_index = RedisCallbackManager::instance().add(redisCallback, false); - if (length > 0) { - if (log_length >= 0) { - std::string redis_command = command + " %d %d %b %b %d"; - int status = redisAsyncCommand( - async_context_, reinterpret_cast(&GlobalRedisCallback), - reinterpret_cast(callback_index), redis_command.c_str(), prefix, - pubsub_channel, id.data(), id.size(), data, length, log_length); - if (status == REDIS_ERR) { - return Status::RedisError(std::string(async_context_->errstr)); - } - } else { - std::string redis_command = command + " %d %d %b %b"; - int status = redisAsyncCommand( - async_context_, reinterpret_cast(&GlobalRedisCallback), - reinterpret_cast(callback_index), redis_command.c_str(), prefix, - pubsub_channel, id.data(), id.size(), data, length); - if (status == REDIS_ERR) { - return Status::RedisError(std::string(async_context_->errstr)); - } - } - } else { - RAY_CHECK(log_length == -1); - std::string redis_command = command + " %d %d %b"; - int status = redisAsyncCommand( - async_context_, reinterpret_cast(&GlobalRedisCallback), - reinterpret_cast(callback_index), redis_command.c_str(), prefix, - pubsub_channel, id.data(), id.size()); - if (status == REDIS_ERR) { - return Status::RedisError(std::string(async_context_->errstr)); - } - } - return Status::OK(); -} - Status RedisContext::RunArgvAsync(const std::vector &args) { // Build the arguments. std::vector argv; diff --git a/src/ray/gcs/redis_context.h b/src/ray/gcs/redis_context.h index 0af5a121e573..93a343464892 100644 --- a/src/ray/gcs/redis_context.h +++ b/src/ray/gcs/redis_context.h @@ -11,6 +11,12 @@ #include "ray/gcs/format/gcs_generated.h" +extern "C" { +#include "ray/thirdparty/hiredis/adapters/ae.h" +#include "ray/thirdparty/hiredis/async.h" +#include "ray/thirdparty/hiredis/hiredis.h" +} + struct redisContext; struct redisAsyncContext; struct aeEventLoop; @@ -22,6 +28,8 @@ namespace gcs { /// operation. using RedisCallback = std::function; +void GlobalRedisCallback(void *c, void *r, void *privdata); + class RedisCallbackManager { public: static RedisCallbackManager &instance() { @@ -83,7 +91,8 @@ class RedisContext { /// at which the data must be appended. For all other commands, set to /// -1 for unused. If set, then data must be provided. /// \return Status. 
- Status RunAsync(const std::string &command, const UniqueID &id, const uint8_t *data, + template + Status RunAsync(const std::string &command, const ID &id, const uint8_t *data, int64_t length, const TablePrefix prefix, const TablePubsub pubsub_channel, RedisCallback redisCallback, int log_length = -1); @@ -113,6 +122,46 @@ class RedisContext { redisAsyncContext *subscribe_context_; }; +template +Status RedisContext::RunAsync(const std::string &command, const ID &id, + const uint8_t *data, int64_t length, + const TablePrefix prefix, const TablePubsub pubsub_channel, + RedisCallback redisCallback, int log_length) { + int64_t callback_index = RedisCallbackManager::instance().add(redisCallback, false); + if (length > 0) { + if (log_length >= 0) { + std::string redis_command = command + " %d %d %b %b %d"; + int status = redisAsyncCommand( + async_context_, reinterpret_cast(&GlobalRedisCallback), + reinterpret_cast(callback_index), redis_command.c_str(), prefix, + pubsub_channel, id.data(), id.size(), data, length, log_length); + if (status == REDIS_ERR) { + return Status::RedisError(std::string(async_context_->errstr)); + } + } else { + std::string redis_command = command + " %d %d %b %b"; + int status = redisAsyncCommand( + async_context_, reinterpret_cast(&GlobalRedisCallback), + reinterpret_cast(callback_index), redis_command.c_str(), prefix, + pubsub_channel, id.data(), id.size(), data, length); + if (status == REDIS_ERR) { + return Status::RedisError(std::string(async_context_->errstr)); + } + } + } else { + RAY_CHECK(log_length == -1); + std::string redis_command = command + " %d %d %b"; + int status = redisAsyncCommand( + async_context_, reinterpret_cast(&GlobalRedisCallback), + reinterpret_cast(callback_index), redis_command.c_str(), prefix, + pubsub_channel, id.data(), id.size()); + if (status == REDIS_ERR) { + return Status::RedisError(std::string(async_context_->errstr)); + } + } + return Status::OK(); +} + } // namespace gcs } // namespace ray diff --git a/src/ray/gcs/redis_module/ray_redis_module.cc b/src/ray/gcs/redis_module/ray_redis_module.cc index 0405367e15f0..b9891e8cae32 100644 --- a/src/ray/gcs/redis_module/ray_redis_module.cc +++ b/src/ray/gcs/redis_module/ray_redis_module.cc @@ -676,13 +676,15 @@ int TableDelete_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int size_t len = 0; const char *data_ptr = nullptr; data_ptr = RedisModule_StringPtrLen(data, &len); - REPLY_AND_RETURN_IF_FALSE( - len % kUniqueIDSize == 0, - "The deletion data length must be a multiple of the UniqueID size."); - size_t ids_to_delete = len / kUniqueIDSize; + // The first uint16_t are used to encode the number of ids to delete. 
+ size_t ids_to_delete = *reinterpret_cast(data_ptr); + size_t id_length = (len - sizeof(uint16_t)) / ids_to_delete; + REPLY_AND_RETURN_IF_FALSE((len - sizeof(uint16_t)) % ids_to_delete == 0, + "The deletion data length must be multiple of the ID size"); + data_ptr += sizeof(uint16_t); for (size_t i = 0; i < ids_to_delete; ++i) { RedisModuleString *id_data = - RedisModule_CreateString(ctx, data_ptr + i * kUniqueIDSize, kUniqueIDSize); + RedisModule_CreateString(ctx, data_ptr + i * id_length, id_length); RAY_IGNORE_EXPR(DeleteKeyHelper(ctx, prefix_str, id_data)); } return RedisModule_ReplyWithSimpleString(ctx, "OK"); diff --git a/src/ray/gcs/tables.cc b/src/ray/gcs/tables.cc index e0876aa73e3e..3d4708940d1a 100644 --- a/src/ray/gcs/tables.cc +++ b/src/ray/gcs/tables.cc @@ -192,15 +192,25 @@ void Log::Delete(const DriverID &driver_id, const std::vector &ids } // Breaking really large deletion commands into batches of smaller size. const size_t batch_size = - RayConfig::instance().maximum_gcs_deletion_batch_size() * kUniqueIDSize; + RayConfig::instance().maximum_gcs_deletion_batch_size() * ID::size(); for (const auto &pair : sharded_data) { std::string current_data = pair.second.str(); for (size_t cur = 0; cur < pair.second.str().size(); cur += batch_size) { - RAY_IGNORE_EXPR(pair.first->RunAsync( - "RAY.TABLE_DELETE", UniqueID::nil(), - reinterpret_cast(current_data.c_str() + cur), - std::min(batch_size, current_data.size() - cur), prefix_, pubsub_channel_, - /*redisCallback=*/nullptr)); + size_t data_field_size = std::min(batch_size, current_data.size() - cur); + uint16_t id_count = data_field_size / ID::size(); + // Send data contains id count and all the id data. + std::string send_data(data_field_size + sizeof(id_count), 0); + uint8_t *buffer = reinterpret_cast(&send_data[0]); + *reinterpret_cast(buffer) = id_count; + RAY_IGNORE_EXPR( + std::copy_n(reinterpret_cast(current_data.c_str() + cur), + data_field_size, buffer + sizeof(uint16_t))); + + RAY_IGNORE_EXPR( + pair.first->RunAsync("RAY.TABLE_DELETE", UniqueID::nil(), + reinterpret_cast(send_data.c_str()), + send_data.size(), prefix_, pubsub_channel_, + /*redisCallback=*/nullptr)); } } } @@ -363,7 +373,7 @@ void ClientTable::RegisterClientAddedCallback(const ClientTableCallback &callbac client_added_callback_ = callback; // Call the callback for any added clients that are cached. for (const auto &entry : client_cache_) { - if (!entry.first.is_nil() && entry.second.is_insertion) { + if (!entry.first.is_nil() && (entry.second.entry_type == EntryType::INSERTION)) { client_added_callback_(client_, entry.first, entry.second); } } @@ -373,55 +383,136 @@ void ClientTable::RegisterClientRemovedCallback(const ClientTableCallback &callb client_removed_callback_ = callback; // Call the callback for any removed clients that are cached. for (const auto &entry : client_cache_) { - if (!entry.first.is_nil() && !entry.second.is_insertion) { + if (!entry.first.is_nil() && entry.second.entry_type == EntryType::DELETION) { client_removed_callback_(client_, entry.first, entry.second); } } } +void ClientTable::RegisterResourceCreateUpdatedCallback( + const ClientTableCallback &callback) { + resource_createupdated_callback_ = callback; + // Call the callback for any clients that are cached. 
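Reviewer note: the `RAY.TABLE_DELETE` payload framing changes in the two hunks above. Instead of a flat concatenation of fixed-size UniqueIDs, each batch now begins with a uint16 count so the redis module can derive the per-ID length, since IDs are no longer all `kUniqueIDSize` bytes. A sketch of that framing, assuming native little-endian byte order as on x86:

    import struct

    def pack_delete_batch(ids):
        # Mirrors what Log<ID, Data>::Delete now sends: a uint16 id count
        # followed by the raw ID bytes; every ID in a batch has the same length.
        assert ids and len({len(i) for i in ids}) == 1
        return struct.pack("<H", len(ids)) + b"".join(ids)

    def unpack_delete_batch(payload):
        # Mirrors the parsing in TableDelete_RedisCommand: read the count,
        # then split the remainder into `count` equally sized IDs.
        (count,) = struct.unpack_from("<H", payload)
        body = payload[2:]
        assert len(body) % count == 0, "deletion data must be a multiple of the ID size"
        id_len = len(body) // count
        return [body[i * id_len:(i + 1) * id_len] for i in range(count)]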
+ for (const auto &entry : client_cache_) { + if (!entry.first.is_nil() && + (entry.second.entry_type == EntryType::RES_CREATEUPDATE)) { + resource_createupdated_callback_(client_, entry.first, entry.second); + } + } +} + +void ClientTable::RegisterResourceDeletedCallback(const ClientTableCallback &callback) { + resource_deleted_callback_ = callback; + // Call the callback for any clients that are cached. + for (const auto &entry : client_cache_) { + if (!entry.first.is_nil() && entry.second.entry_type == EntryType::RES_DELETE) { + resource_deleted_callback_(client_, entry.first, entry.second); + } + } +} + void ClientTable::HandleNotification(AsyncGcsClient *client, const ClientTableDataT &data) { ClientID client_id = ClientID::from_binary(data.client_id); // It's possible to get duplicate notifications from the client table, so // check whether this notification is new. auto entry = client_cache_.find(client_id); - bool is_new; + bool is_notif_new; if (entry == client_cache_.end()) { // If the entry is not in the cache, then the notification is new. - is_new = true; + is_notif_new = true; } else { // If the entry is in the cache, then the notification is new if the client - // was alive and is now dead. - bool was_inserted = entry->second.is_insertion; - bool is_deleted = !data.is_insertion; - is_new = (was_inserted && is_deleted); + // was alive and is now dead or resources have been updated. + bool was_not_deleted = (entry->second.entry_type != EntryType::DELETION); + bool is_deleted = (data.entry_type == EntryType::DELETION); + bool is_res_modified = ((data.entry_type == EntryType::RES_CREATEUPDATE) || + (data.entry_type == EntryType::RES_DELETE)); + is_notif_new = (was_not_deleted && (is_deleted || is_res_modified)); // Once a client with a given ID has been removed, it should never be added // again. If the entry was in the cache and the client was deleted, check // that this new notification is not an insertion. - if (!entry->second.is_insertion) { - RAY_CHECK(!data.is_insertion) + if (entry->second.entry_type == EntryType::DELETION) { + RAY_CHECK((data.entry_type == EntryType::DELETION)) << "Notification for addition of a client that was already removed:" << client_id; } } // Add the notification to our cache. Notifications are idempotent. - client_cache_[client_id] = data; + // If it is a new client or a client removal, add as is + if ((data.entry_type == EntryType::INSERTION) || + (data.entry_type == EntryType::DELETION)) { + RAY_LOG(DEBUG) << "[ClientTableNotification] ClientTable Insertion/Deletion " + "notification for client id " + << client_id << ". EntryType: " << int(data.entry_type) + << ". Setting the client cache to data."; + client_cache_[client_id] = data; + } else if ((data.entry_type == EntryType::RES_CREATEUPDATE) || + (data.entry_type == EntryType::RES_DELETE)) { + RAY_LOG(DEBUG) << "[ClientTableNotification] ClientTable RES_CREATEUPDATE " + "notification for client id " + << client_id << ". EntryType: " << int(data.entry_type) + << ". 
Updating the client cache with the delta from the log."; + + ClientTableDataT &cache_data = client_cache_[client_id]; + // Iterate over all resources in the new create/update notification + for (std::vector::size_type i = 0; i != data.resources_total_label.size(); i++) { + auto const &resource_name = data.resources_total_label[i]; + auto const &capacity = data.resources_total_capacity[i]; + + // If resource exists in the ClientTableData, update it, else create it + auto existing_resource_label = + std::find(cache_data.resources_total_label.begin(), + cache_data.resources_total_label.end(), resource_name); + if (existing_resource_label != cache_data.resources_total_label.end()) { + auto index = std::distance(cache_data.resources_total_label.begin(), + existing_resource_label); + // Resource already exists, set capacity if updation call.. + if (data.entry_type == EntryType::RES_CREATEUPDATE) { + cache_data.resources_total_capacity[index] = capacity; + } + // .. delete if deletion call. + else if (data.entry_type == EntryType::RES_DELETE) { + cache_data.resources_total_label.erase( + cache_data.resources_total_label.begin() + index); + cache_data.resources_total_capacity.erase( + cache_data.resources_total_capacity.begin() + index); + } + } else { + // Resource does not exist, create resource and add capacity if it was a resource + // create call. + if (data.entry_type == EntryType::RES_CREATEUPDATE) { + cache_data.resources_total_label.push_back(resource_name); + cache_data.resources_total_capacity.push_back(capacity); + } + } + } + } // If the notification is new, call any registered callbacks. - if (is_new) { - if (data.is_insertion) { + ClientTableDataT &cache_data = client_cache_[client_id]; + if (is_notif_new) { + if (data.entry_type == EntryType::INSERTION) { if (client_added_callback_ != nullptr) { - client_added_callback_(client, client_id, data); + client_added_callback_(client, client_id, cache_data); } RAY_CHECK(removed_clients_.find(client_id) == removed_clients_.end()); - } else { + } else if (data.entry_type == EntryType::DELETION) { // NOTE(swang): The client should be added to this data structure before // the callback gets called, in case the callback depends on the data // structure getting updated. removed_clients_.insert(client_id); if (client_removed_callback_ != nullptr) { - client_removed_callback_(client, client_id, data); + client_removed_callback_(client, client_id, cache_data); + } + } else if (data.entry_type == EntryType::RES_CREATEUPDATE) { + if (resource_createupdated_callback_ != nullptr) { + resource_createupdated_callback_(client, client_id, cache_data); + } + } else if (data.entry_type == EntryType::RES_DELETE) { + if (resource_deleted_callback_ != nullptr) { + resource_deleted_callback_(client, client_id, cache_data); } } } @@ -449,7 +540,7 @@ Status ClientTable::Connect(const ClientTableDataT &local_client) { // Construct the data to add to the client table. auto data = std::make_shared(local_client_); - data->is_insertion = true; + data->entry_type = EntryType::INSERTION; // Callback to handle our own successful connection once we've added // ourselves. auto add_callback = [this](AsyncGcsClient *client, const UniqueID &log_key, @@ -467,7 +558,7 @@ Status ClientTable::Connect(const ClientTableDataT &local_client) { for (auto ¬ification : notifications) { // This is temporary fix for Issue 4140 to avoid connect to dead nodes. // TODO(yuhguo): remove this temporary fix after GCS entry is removable. 
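Reviewer note: the old `is_insertion` flag becomes a four-state `EntryType`, and `ClientTable::HandleNotification` now merges resource create/update/delete notifications into the cached `ClientTableDataT` instead of overwriting it. A Python sketch of that merge over the parallel label/capacity vectors; the enum values mirror the ones added to gcs.fbs.

    from enum import IntEnum

    class EntryType(IntEnum):      # mirrors the EntryType enum in gcs.fbs
        INSERTION = 0
        DELETION = 1
        RES_CREATEUPDATE = 2
        RES_DELETE = 3

    def merge_resource_notification(labels, capacities, entry_type,
                                    new_labels, new_capacities):
        # labels/capacities are the cached resources_total_label and
        # resources_total_capacity lists; the notification carries a delta.
        for name, cap in zip(new_labels, new_capacities):
            if name in labels:
                i = labels.index(name)
                if entry_type == EntryType.RES_CREATEUPDATE:
                    capacities[i] = cap        # update existing capacity
                elif entry_type == EntryType.RES_DELETE:
                    del labels[i]              # drop the resource entirely
                    del capacities[i]
            elif entry_type == EntryType.RES_CREATEUPDATE:
                labels.append(name)            # create a new resource
                capacities.append(cap)
        return labels, capacities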
- if (notification.is_insertion) { + if (notification.entry_type != EntryType::DELETION) { connected_nodes.emplace(notification.client_id, notification); } else { auto iter = connected_nodes.find(notification.client_id); @@ -498,7 +589,7 @@ Status ClientTable::Connect(const ClientTableDataT &local_client) { Status ClientTable::Disconnect(const DisconnectCallback &callback) { auto data = std::make_shared(local_client_); - data->is_insertion = false; + data->entry_type = EntryType::DELETION; auto add_callback = [this, callback](AsyncGcsClient *client, const ClientID &id, const ClientTableDataT &data) { HandleConnected(client, data); @@ -516,7 +607,7 @@ Status ClientTable::Disconnect(const DisconnectCallback &callback) { ray::Status ClientTable::MarkDisconnected(const ClientID &dead_client_id) { auto data = std::make_shared(); data->client_id = dead_client_id.binary(); - data->is_insertion = false; + data->entry_type = EntryType::DELETION; return Append(DriverID::nil(), client_log_key_, data, nullptr); } diff --git a/src/ray/gcs/tables.h b/src/ray/gcs/tables.h index b229108328f9..58a087d8c666 100644 --- a/src/ray/gcs/tables.h +++ b/src/ray/gcs/tables.h @@ -206,7 +206,7 @@ class Log : public LogInterface, virtual public PubsubInterface { protected: std::shared_ptr GetRedisContext(const ID &id) { - static std::hash index; + static std::hash index; return shard_contexts_[index(id) % shard_contexts_.size()]; } @@ -677,7 +677,7 @@ using ConfigTable = Table; /// it should append an entry to the log indicating that it is dead. A client /// that is marked as dead should never again be marked as alive; if it needs /// to reconnect, it must connect with a different ClientID. -class ClientTable : private Log { +class ClientTable : public Log { public: using ClientTableCallback = std::function; @@ -729,6 +729,16 @@ class ClientTable : private Log { /// \param callback The callback to register. void RegisterClientRemovedCallback(const ClientTableCallback &callback); + /// Register a callback to call when a resource is created or updated. + /// + /// \param callback The callback to register. + void RegisterResourceCreateUpdatedCallback(const ClientTableCallback &callback); + + /// Register a callback to call when a resource is deleted. + /// + /// \param callback The callback to register. + void RegisterResourceDeletedCallback(const ClientTableCallback &callback); + /// Get a client's information from the cache. The cache only contains /// information for clients that we've heard a notification for. /// @@ -772,16 +782,16 @@ class ClientTable : private Log { /// \return string. std::string DebugString() const; + /// The key at which the log of client information is stored. This key must + /// be kept the same across all instances of the ClientTable, so that all + /// clients append and read from the same key. + ClientID client_log_key_; + private: /// Handle a client table notification. void HandleNotification(AsyncGcsClient *client, const ClientTableDataT ¬ifications); /// Handle this client's successful connection to the GCS. void HandleConnected(AsyncGcsClient *client, const ClientTableDataT &client_data); - - /// The key at which the log of client information is stored. This key must - /// be kept the same across all instances of the ClientTable, so that all - /// clients append and read from the same key. - ClientID client_log_key_; /// Whether this client has called Disconnect(). bool disconnected_; /// This client's ID. 
@@ -792,6 +802,10 @@ class ClientTable : private Log { ClientTableCallback client_added_callback_; /// The callback to call when a client is removed. ClientTableCallback client_removed_callback_; + /// The callback to call when a resource is created or updated. + ClientTableCallback resource_createupdated_callback_; + /// The callback to call when a resource is deleted. + ClientTableCallback resource_deleted_callback_; /// A cache for information about all clients. std::unordered_map client_cache_; /// The set of removed clients. diff --git a/src/ray/id.cc b/src/ray/id.cc index 8d72cef8b300..a011430ad1cf 100644 --- a/src/ray/id.cc +++ b/src/ray/id.cc @@ -26,82 +26,16 @@ std::mt19937 RandomlySeededMersenneTwister() { uint64_t MurmurHash64A(const void *key, int len, unsigned int seed); -UniqueID::UniqueID() { - // Set the ID to nil. - std::fill_n(id_, kUniqueIDSize, 255); -} - -UniqueID::UniqueID(const std::string &binary) { - std::memcpy(&id_, binary.data(), kUniqueIDSize); -} - -UniqueID::UniqueID(const plasma::UniqueID &from) { - std::memcpy(&id_, from.data(), kUniqueIDSize); -} - -UniqueID UniqueID::from_random() { - std::string data(kUniqueIDSize, 0); - // NOTE(pcm): The right way to do this is to have one std::mt19937 per - // thread (using the thread_local keyword), but that's not supported on - // older versions of macOS (see https://stackoverflow.com/a/29929949) - static std::mutex random_engine_mutex; - std::lock_guard lock(random_engine_mutex); - static std::mt19937 generator = RandomlySeededMersenneTwister(); - std::uniform_int_distribution dist(0, std::numeric_limits::max()); - for (int i = 0; i < kUniqueIDSize; i++) { - data[i] = static_cast(dist(generator)); - } - return UniqueID::from_binary(data); -} - -UniqueID UniqueID::from_binary(const std::string &binary) { return UniqueID(binary); } - -const UniqueID &UniqueID::nil() { - static const UniqueID nil_id; - return nil_id; -} - -bool UniqueID::is_nil() const { - const uint8_t *d = data(); - for (int i = 0; i < kUniqueIDSize; ++i) { - if (d[i] != 255) { - return false; - } - } - return true; -} - -const uint8_t *UniqueID::data() const { return id_; } - -size_t UniqueID::size() { return kUniqueIDSize; } - -std::string UniqueID::binary() const { - return std::string(reinterpret_cast(id_), kUniqueIDSize); -} - -std::string UniqueID::hex() const { - constexpr char hex[] = "0123456789abcdef"; - std::string result; - for (int i = 0; i < kUniqueIDSize; i++) { - unsigned int val = id_[i]; - result.push_back(hex[val >> 4]); - result.push_back(hex[val & 0xf]); - } - return result; -} - -plasma::UniqueID UniqueID::to_plasma_id() const { +plasma::UniqueID ObjectID::to_plasma_id() const { plasma::UniqueID result; - std::memcpy(result.mutable_data(), &id_, kUniqueIDSize); + std::memcpy(result.mutable_data(), data(), kUniqueIDSize); return result; } -bool UniqueID::operator==(const UniqueID &rhs) const { - return std::memcmp(data(), rhs.data(), kUniqueIDSize) == 0; +ObjectID::ObjectID(const plasma::UniqueID &from) { + std::memcpy(this->mutable_data(), from.data(), kUniqueIDSize); } -bool UniqueID::operator!=(const UniqueID &rhs) const { return !(*this == rhs); } - // This code is from https://sites.google.com/site/murmurhash/ // and is public domain. 
uint64_t MurmurHash64A(const void *key, int len, unsigned int seed) { @@ -151,60 +85,32 @@ uint64_t MurmurHash64A(const void *key, int len, unsigned int seed) { return h; } -size_t UniqueID::hash() const { - // Note(ashione): hash code lazy calculation(it's invoked every time if hash code is - // default value 0) - if (!hash_) { - hash_ = MurmurHash64A(&id_[0], kUniqueIDSize, 0); - } - return hash_; +TaskID TaskID::GetDriverTaskID(const DriverID &driver_id) { + std::string driver_id_str = driver_id.binary(); + driver_id_str.resize(size()); + return TaskID::from_binary(driver_id_str); } -std::ostream &operator<<(std::ostream &os, const UniqueID &id) { - if (id.is_nil()) { - os << "NIL_ID"; - } else { - os << id.hex(); - } - return os; +TaskID ObjectID::task_id() const { + return TaskID::from_binary( + std::string(reinterpret_cast(id_), TaskID::size())); } -const ObjectID ComputeObjectId(const TaskID &task_id, int64_t object_index) { - RAY_CHECK(object_index <= kMaxTaskReturns && object_index >= -kMaxTaskPuts); - ObjectID return_id = ObjectID(task_id); - int64_t *first_bytes = reinterpret_cast(&return_id); - // Zero out the lowest kObjectIdIndexSize bits of the first byte of the - // object ID. - uint64_t bitmask = static_cast(-1) << kObjectIdIndexSize; - *first_bytes = *first_bytes & (bitmask); - // OR the first byte of the object ID with the return index. - *first_bytes = *first_bytes | (object_index & ~bitmask); - return return_id; +ObjectID ObjectID::for_put(const TaskID &task_id, int64_t put_index) { + RAY_CHECK(put_index >= 1 && put_index <= kMaxTaskPuts) << "index=" << put_index; + ObjectID object_id; + std::memcpy(object_id.id_, task_id.binary().c_str(), task_id.size()); + object_id.index_ = -put_index; + return object_id; } -const TaskID FinishTaskId(const TaskID &task_id) { - return TaskID(ComputeObjectId(task_id, 0)); -} - -const ObjectID ComputeReturnId(const TaskID &task_id, int64_t return_index) { - RAY_CHECK(return_index >= 1 && return_index <= kMaxTaskReturns); - return ComputeObjectId(task_id, return_index); -} - -const ObjectID ComputePutId(const TaskID &task_id, int64_t put_index) { - RAY_CHECK(put_index >= 1 && put_index <= kMaxTaskPuts); - // We multiply put_index by -1 to distinguish from return_index. - return ComputeObjectId(task_id, -1 * put_index); -} - -const TaskID ComputeTaskId(const ObjectID &object_id) { - TaskID task_id = TaskID(object_id); - int64_t *first_bytes = reinterpret_cast(&task_id); - // Zero out the lowest kObjectIdIndexSize bits of the first byte of the - // object ID. - uint64_t bitmask = static_cast(-1) << kObjectIdIndexSize; - *first_bytes = *first_bytes & (bitmask); - return task_id; +ObjectID ObjectID::for_task_return(const TaskID &task_id, int64_t return_index) { + RAY_CHECK(return_index >= 1 && return_index <= kMaxTaskReturns) << "index=" + << return_index; + ObjectID object_id; + std::memcpy(object_id.id_, task_id.binary().c_str(), task_id.size()); + object_id.index_ = return_index; + return object_id; } const TaskID GenerateTaskId(const DriverID &driver_id, const TaskID &parent_task_id, @@ -220,16 +126,21 @@ const TaskID GenerateTaskId(const DriverID &driver_id, const TaskID &parent_task // Compute the final task ID from the hash. 
BYTE buff[DIGEST_SIZE]; sha256_final(&ctx, buff); - return FinishTaskId(TaskID::from_binary(std::string(buff, buff + kUniqueIDSize))); + return TaskID::from_binary(std::string(buff, buff + TaskID::size())); } -int64_t ComputeObjectIndex(const ObjectID &object_id) { - const int64_t *first_bytes = reinterpret_cast(&object_id); - uint64_t bitmask = static_cast(-1) << kObjectIdIndexSize; - int64_t index = *first_bytes & (~bitmask); - index <<= (8 * sizeof(int64_t) - kObjectIdIndexSize); - index >>= (8 * sizeof(int64_t) - kObjectIdIndexSize); - return index; -} +#define ID_OSTREAM_OPERATOR(id_type) \ + std::ostream &operator<<(std::ostream &os, const id_type &id) { \ + if (id.is_nil()) { \ + os << "NIL_ID"; \ + } else { \ + os << id.hex(); \ + } \ + return os; \ + } + +ID_OSTREAM_OPERATOR(UniqueID); +ID_OSTREAM_OPERATOR(TaskID); +ID_OSTREAM_OPERATOR(ObjectID); } // namespace ray diff --git a/src/ray/id.h b/src/ray/id.h index 9467c1a3f11d..f90f66549358 100644 --- a/src/ray/id.h +++ b/src/ray/id.h @@ -2,44 +2,128 @@ #define RAY_ID_H_ #include +#include +#include #include +#include +#include #include #include "plasma/common.h" #include "ray/constants.h" +#include "ray/util/logging.h" #include "ray/util/visibility.h" namespace ray { -class RAY_EXPORT UniqueID { +class DriverID; +class UniqueID; + +// Declaration. +std::mt19937 RandomlySeededMersenneTwister(); +uint64_t MurmurHash64A(const void *key, int len, unsigned int seed); + +// Change the compiler alignment to 1 byte (default is 8). +#pragma pack(push, 1) + +template +class BaseID { public: - UniqueID(); - UniqueID(const plasma::UniqueID &from); - static UniqueID from_random(); - static UniqueID from_binary(const std::string &binary); - static const UniqueID &nil(); + BaseID(); + static T from_random(); + static T from_binary(const std::string &binary); + static const T &nil(); + static size_t size() { return T::size(); } + size_t hash() const; bool is_nil() const; - bool operator==(const UniqueID &rhs) const; - bool operator!=(const UniqueID &rhs) const; + bool operator==(const BaseID &rhs) const; + bool operator!=(const BaseID &rhs) const; const uint8_t *data() const; - static size_t size(); std::string binary() const; std::string hex() const; - plasma::UniqueID to_plasma_id() const; - private: + protected: + BaseID(const std::string &binary) { + std::memcpy(const_cast(this->data()), binary.data(), T::size()); + } + // All IDs are immutable for hash evaluations. mutable_data is only allow to use + // in construction time, so this function is protected. + uint8_t *mutable_data(); + // For lazy evaluation, be careful to have one Id contained in another. + // This hash code will be duplicated. + mutable size_t hash_ = 0; +}; + +class UniqueID : public BaseID { + public: + UniqueID() : BaseID(){}; + static size_t size() { return kUniqueIDSize; } + + protected: UniqueID(const std::string &binary); protected: uint8_t id_[kUniqueIDSize]; - mutable size_t hash_ = 0; }; -static_assert(std::is_standard_layout::value, "UniqueID must be standard"); +class TaskID : public BaseID { + public: + TaskID() : BaseID() {} + static size_t size() { return kTaskIDSize; } + static TaskID GetDriverTaskID(const DriverID &driver_id); + + private: + uint8_t id_[kTaskIDSize]; +}; + +class ObjectID : public BaseID { + public: + ObjectID() : BaseID() {} + static size_t size() { return kUniqueIDSize; } + plasma::ObjectID to_plasma_id() const; + ObjectID(const plasma::UniqueID &from); + + /// Get the index of this object in the task that created it. 
+ /// + /// \return The index of object creation according to the task that created + /// this object. This is positive if the task returned the object and negative + /// if created by a put. + int32_t object_index() const { return index_; } + + /// Compute the task ID of the task that created the object. + /// + /// \return The task ID of the task that created this object. + TaskID task_id() const; + + /// Compute the object ID of an object put by the task. + /// + /// \param task_id The task ID of the task that created the object. + /// \param index What index of the object put in the task. + /// \return The computed object ID. + static ObjectID for_put(const TaskID &task_id, int64_t put_index); + + /// Compute the object ID of an object returned by the task. + /// + /// \param task_id The task ID of the task that created the object. + /// \param return_index What index of the object returned by in the task. + /// \return The computed object ID. + static ObjectID for_task_return(const TaskID &task_id, int64_t return_index); + + private: + uint8_t id_[kTaskIDSize]; + int32_t index_; +}; + +static_assert(sizeof(TaskID) == kTaskIDSize + sizeof(size_t), + "TaskID size is not as expected"); +static_assert(sizeof(ObjectID) == sizeof(int32_t) + sizeof(TaskID), + "ObjectID size is not as expected"); std::ostream &operator<<(std::ostream &os, const UniqueID &id); +std::ostream &operator<<(std::ostream &os, const TaskID &id); +std::ostream &operator<<(std::ostream &os, const ObjectID &id); #define DEFINE_UNIQUE_ID(type) \ class RAY_EXPORT type : public UniqueID { \ @@ -63,35 +147,8 @@ std::ostream &operator<<(std::ostream &os, const UniqueID &id); #undef DEFINE_UNIQUE_ID -// TODO(swang): ObjectID and TaskID should derive from UniqueID. Then, we -// can make these methods of the derived classes. -/// Finish computing a task ID. Since objects created by the task share a -/// prefix of the ID, the suffix of the task ID is zeroed out by this function. -/// -/// \param task_id A task ID to finish. -/// \return The finished task ID. It may now be used to compute IDs for objects -/// created by the task. -const TaskID FinishTaskId(const TaskID &task_id); - -/// Compute the object ID of an object returned by the task. -/// -/// \param task_id The task ID of the task that created the object. -/// \param return_index What number return value this object is in the task. -/// \return The computed object ID. -const ObjectID ComputeReturnId(const TaskID &task_id, int64_t return_index); - -/// Compute the object ID of an object put by the task. -/// -/// \param task_id The task ID of the task that created the object. -/// \param put_index What number put this object was created by in the task. -/// \return The computed object ID. -const ObjectID ComputePutId(const TaskID &task_id, int64_t put_index); - -/// Compute the task ID of the task that created the object. -/// -/// \param object_id The object ID. -/// \return The task ID of the task that created this object. -const TaskID ComputeTaskId(const ObjectID &object_id); +// Restore the compiler alignment to defult (8 bytes). +#pragma pack(pop) /// Generate a task ID from the given info. /// @@ -102,13 +159,95 @@ const TaskID ComputeTaskId(const ObjectID &object_id); const TaskID GenerateTaskId(const DriverID &driver_id, const TaskID &parent_task_id, int parent_task_counter); -/// Compute the index of this object in the task that created it. -/// -/// \param object_id The object ID. 
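Reviewer note: to make the new ID layout concrete: with `kUniqueIDSize = 20` and a 32-bit object index, `kTaskIDSize = 20 - 32/8 = 16`, so an ObjectID is just the 16-byte ID of the task that created it followed by a signed 32-bit index (positive for task returns, negative for puts). A byte-level sketch, assuming the index is stored little-endian as on x86:

    import struct

    K_UNIQUE_ID_SIZE = 20
    K_TASK_ID_SIZE = K_UNIQUE_ID_SIZE - 32 // 8   # = 16

    def object_id_for_task_return(task_id, return_index):
        # Mirrors ObjectID::for_task_return: task ID bytes + positive index.
        assert len(task_id) == K_TASK_ID_SIZE and return_index >= 1
        return task_id + struct.pack("<i", return_index)

    def object_id_for_put(task_id, put_index):
        # Mirrors ObjectID::for_put: puts use the negated index.
        assert len(task_id) == K_TASK_ID_SIZE and put_index >= 1
        return task_id + struct.pack("<i", -put_index)

    def task_id_of(object_id):
        # Mirrors ObjectID::task_id: the creating task's ID is a plain prefix.
        return object_id[:K_TASK_ID_SIZE]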
-/// \return The index of object creation according to the task that created -/// this object. This is positive if the task returned the object and negative -/// if created by a put. -int64_t ComputeObjectIndex(const ObjectID &object_id); +template +BaseID::BaseID() { + // Using const_cast to directly change data is dangerous. The cached + // hash may not be changed. This is used in construction time. + std::fill_n(this->mutable_data(), T::size(), 0xff); +} + +template +T BaseID::from_random() { + std::string data(T::size(), 0); + // NOTE(pcm): The right way to do this is to have one std::mt19937 per + // thread (using the thread_local keyword), but that's not supported on + // older versions of macOS (see https://stackoverflow.com/a/29929949) + static std::mutex random_engine_mutex; + std::lock_guard lock(random_engine_mutex); + static std::mt19937 generator = RandomlySeededMersenneTwister(); + std::uniform_int_distribution dist(0, std::numeric_limits::max()); + for (int i = 0; i < T::size(); i++) { + data[i] = static_cast(dist(generator)); + } + return T::from_binary(data); +} + +template +T BaseID::from_binary(const std::string &binary) { + T t = T::nil(); + std::memcpy(t.mutable_data(), binary.data(), T::size()); + return t; +} + +template +const T &BaseID::nil() { + static const T nil_id; + return nil_id; +} + +template +bool BaseID::is_nil() const { + static T nil_id = T::nil(); + return *this == nil_id; +} + +template +size_t BaseID::hash() const { + // Note(ashione): hash code lazy calculation(it's invoked every time if hash code is + // default value 0) + if (!hash_) { + hash_ = MurmurHash64A(data(), T::size(), 0); + } + return hash_; +} + +template +bool BaseID::operator==(const BaseID &rhs) const { + return std::memcmp(data(), rhs.data(), T::size()) == 0; +} + +template +bool BaseID::operator!=(const BaseID &rhs) const { + return !(*this == rhs); +} + +template +uint8_t *BaseID::mutable_data() { + return reinterpret_cast(this) + sizeof(hash_); +} + +template +const uint8_t *BaseID::data() const { + return reinterpret_cast(this) + sizeof(hash_); +} + +template +std::string BaseID::binary() const { + return std::string(reinterpret_cast(data()), T::size()); +} + +template +std::string BaseID::hex() const { + constexpr char hex[] = "0123456789abcdef"; + const uint8_t *id = data(); + std::string result; + for (int i = 0; i < T::size(); i++) { + unsigned int val = id[i]; + result.push_back(hex[val >> 4]); + result.push_back(hex[val & 0xf]); + } + return result; +} } // namespace ray @@ -125,6 +264,8 @@ namespace std { }; DEFINE_UNIQUE_ID(UniqueID); +DEFINE_UNIQUE_ID(TaskID); +DEFINE_UNIQUE_ID(ObjectID); #include "id_def.h" #undef DEFINE_UNIQUE_ID diff --git a/src/ray/id_def.h b/src/ray/id_def.h index 8a5e7e943262..96c7d59d1098 100644 --- a/src/ray/id_def.h +++ b/src/ray/id_def.h @@ -4,8 +4,6 @@ // Macro definition format: DEFINE_UNIQUE_ID(id_type). // NOTE: This file should NOT be included in any file other than id.h. 
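Reviewer note: a compact Python analogue of the `BaseID<T>` behavior defined above, for readers skimming the template: a fixed-size byte string whose nil value is all 0xff bytes, with a lazily cached hash. The real code uses MurmurHash64A and lays the cached hash out in front of the ID bytes under `#pragma pack(1)`; the stand-ins below are only for illustration.

    class BaseID:
        SIZE = 20                       # subclasses override (TaskID uses 16)

        def __init__(self, binary=None):
            self._id = binary if binary is not None else b"\xff" * self.SIZE
            self._hash = 0              # 0 means "not computed yet"

        @classmethod
        def nil(cls):
            return cls()

        def is_nil(self):
            return self._id == b"\xff" * self.SIZE

        def hash(self):
            if not self._hash:                 # lazy, cached on first use
                self._hash = hash(self._id)    # stand-in for MurmurHash64A
            return self._hash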
-DEFINE_UNIQUE_ID(TaskID) -DEFINE_UNIQUE_ID(ObjectID) DEFINE_UNIQUE_ID(FunctionID) DEFINE_UNIQUE_ID(ActorClassID) DEFINE_UNIQUE_ID(ActorID) diff --git a/src/ray/object_manager/object_directory.cc b/src/ray/object_manager/object_directory.cc index 99ed0851cfb9..85157abcdbe9 100644 --- a/src/ray/object_manager/object_directory.cc +++ b/src/ray/object_manager/object_directory.cc @@ -108,7 +108,7 @@ void ObjectDirectory::LookupRemoteConnectionInfo( ClientID result_client_id = ClientID::from_binary(client_data.client_id); if (!result_client_id.is_nil()) { RAY_CHECK(result_client_id == connection_info.client_id); - if (client_data.is_insertion) { + if (client_data.entry_type == EntryType::INSERTION) { connection_info.ip = client_data.node_manager_address; connection_info.port = static_cast(client_data.object_manager_port); } diff --git a/src/ray/object_manager/object_store_notification_manager.cc b/src/ray/object_manager/object_store_notification_manager.cc index 746f4d622d5a..5245a94ace3a 100644 --- a/src/ray/object_manager/object_store_notification_manager.cc +++ b/src/ray/object_manager/object_store_notification_manager.cc @@ -42,6 +42,11 @@ void ObjectStoreNotificationManager::NotificationWait() { void ObjectStoreNotificationManager::ProcessStoreLength( const boost::system::error_code &error) { notification_.resize(length_); + if (error) { + RAY_LOG(FATAL) + << "Problem communicating with the object store from raylet, check logs or " + << "dmesg for previous errors: " << boost_to_ray_status(error).ToString(); + } boost::asio::async_read( socket_, boost::asio::buffer(notification_), boost::bind(&ObjectStoreNotificationManager::ProcessStoreNotification, this, @@ -50,7 +55,7 @@ void ObjectStoreNotificationManager::ProcessStoreLength( void ObjectStoreNotificationManager::ProcessStoreNotification( const boost::system::error_code &error) { - if (error.value() != boost::system::errc::success) { + if (error) { RAY_LOG(FATAL) << "Problem communicating with the object store from raylet, check logs or " << "dmesg for previous errors: " << boost_to_ray_status(error).ToString(); diff --git a/src/ray/object_manager/test/object_manager_test.cc b/src/ray/object_manager/test/object_manager_test.cc index a373ea9b9365..98eeb9186192 100644 --- a/src/ray/object_manager/test/object_manager_test.cc +++ b/src/ray/object_manager/test/object_manager_test.cc @@ -288,7 +288,7 @@ class TestObjectManager : public TestObjectManagerBase { // object. ObjectID object_1 = WriteDataToClient(client2, data_size); ObjectID object_2 = WriteDataToClient(client2, data_size); - UniqueID sub_id = ray::ObjectID::from_random(); + UniqueID sub_id = ray::UniqueID::from_random(); RAY_CHECK_OK(server1->object_manager_.object_directory_->SubscribeObjectLocations( sub_id, object_1, [this, sub_id, object_1, object_2]( diff --git a/src/ray/raylet/format/node_manager.fbs b/src/ray/raylet/format/node_manager.fbs index f673e2251548..a5b041f29cae 100644 --- a/src/ray/raylet/format/node_manager.fbs +++ b/src/ray/raylet/format/node_manager.fbs @@ -77,6 +77,8 @@ enum MessageType:int { NotifyActorResumedFromCheckpoint, // A node manager requests to connect to another node manager. ConnectClient, + // Set dynamic custom resource + SetResourceRequest, } table TaskExecutionSpecification { @@ -234,3 +236,12 @@ table ConnectClient { // ID of the connecting client. 
client_id: string; } + +table SetResourceRequest{ + // Name of the resource to be set + resource_name: string; + // Capacity of the resource to be set + capacity: double; + // Client ID where this resource will be set + client_id: string; +} diff --git a/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc b/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc index eb9d2f0e5a83..ac32911ef2d0 100644 --- a/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc +++ b/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc @@ -302,6 +302,24 @@ Java_org_ray_runtime_raylet_RayletClientImpl_nativeNotifyActorResumedFromCheckpo ThrowRayExceptionIfNotOK(env, status); } +/* + * Class: org_ray_runtime_raylet_RayletClientImpl + * Method: nativeSetResource + * Signature: (JLjava/lang/String;D[B)V + */ +JNIEXPORT void JNICALL +Java_org_ray_runtime_raylet_RayletClientImpl_nativeSetResource(JNIEnv *env, jclass, + jlong client, jstring resourceName, jdouble capacity, jbyteArray nodeId) { + auto raylet_client = reinterpret_cast(client); + UniqueIdFromJByteArray node_id(env, nodeId); + const char *native_resource_name = env->GetStringUTFChars(resourceName, JNI_FALSE); + + auto status = raylet_client->SetResource(native_resource_name, + static_cast(capacity), node_id.GetId()); + env->ReleaseStringUTFChars(resourceName, native_resource_name); + ThrowRayExceptionIfNotOK(env, status); +} + #ifdef __cplusplus } #endif diff --git a/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.h b/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.h index c00c7c009814..91338a12e176 100644 --- a/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.h +++ b/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.h @@ -116,6 +116,14 @@ JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeNotifyActorResumedFromCheckpoint( JNIEnv *, jclass, jlong, jbyteArray, jbyteArray); +/* + * Class: org_ray_runtime_raylet_RayletClientImpl + * Method: nativeSetResource + * Signature: (JLjava/lang/String;D[B)V + */ +JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeSetResource( + JNIEnv *, jclass, jlong, jstring, jdouble, jbyteArray); + #ifdef __cplusplus } #endif diff --git a/src/ray/raylet/lineage_cache.cc b/src/ray/raylet/lineage_cache.cc index 94f1dc11f189..4c3fac24f19e 100644 --- a/src/ray/raylet/lineage_cache.cc +++ b/src/ray/raylet/lineage_cache.cc @@ -48,7 +48,7 @@ void LineageEntry::ComputeParentTaskIds() { parent_task_ids_.clear(); // A task's parents are the tasks that created its arguments. for (const auto &dependency : task_.GetDependencies()) { - parent_task_ids_.insert(ComputeTaskId(dependency)); + parent_task_ids_.insert(dependency.task_id()); } } diff --git a/src/ray/raylet/monitor.cc b/src/ray/raylet/monitor.cc index 51e035b1b1e6..1e20fe3f4131 100644 --- a/src/ray/raylet/monitor.cc +++ b/src/ray/raylet/monitor.cc @@ -52,7 +52,8 @@ void Monitor::Tick() { const std::vector &all_data) { bool marked = false; for (const auto &data : all_data) { - if (client_id.binary() == data.client_id && !data.is_insertion) { + if (client_id.binary() == data.client_id && + data.entry_type == EntryType::DELETION) { // The node has been marked dead by itself. 
marked = true; } diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index f901c6800eb5..2e25407f12fb 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -185,6 +185,22 @@ ray::Status NodeManager::RegisterGcs() { }; gcs_client_->client_table().RegisterClientRemovedCallback(node_manager_client_removed); + // Register a callback on the client table for resource create/update requests + auto node_manager_resource_createupdated = [this]( + gcs::AsyncGcsClient *client, const UniqueID &id, const ClientTableDataT &data) { + ResourceCreateUpdated(data); + }; + gcs_client_->client_table().RegisterResourceCreateUpdatedCallback( + node_manager_resource_createupdated); + + // Register a callback on the client table for resource delete requests + auto node_manager_resource_deleted = [this]( + gcs::AsyncGcsClient *client, const UniqueID &id, const ClientTableDataT &data) { + ResourceDeleted(data); + }; + gcs_client_->client_table().RegisterResourceDeletedCallback( + node_manager_resource_deleted); + // Subscribe to heartbeat batches from the monitor. const auto &heartbeat_batch_added = [this]( gcs::AsyncGcsClient *client, const ClientID &id, @@ -461,6 +477,92 @@ void NodeManager::ClientRemoved(const ClientTableDataT &client_data) { object_directory_->HandleClientRemoved(client_id); } +void NodeManager::ResourceCreateUpdated(const ClientTableDataT &client_data) { + const ClientID client_id = ClientID::from_binary(client_data.client_id); + const ClientID &local_client_id = gcs_client_->client_table().GetLocalClientId(); + + RAY_LOG(DEBUG) << "[ResourceCreateUpdated] received callback from client id " + << client_id << ". Updating resource map."; + ResourceSet new_res_set(client_data.resources_total_label, + client_data.resources_total_capacity); + + const ResourceSet &old_res_set = cluster_resource_map_[client_id].GetTotalResources(); + ResourceSet difference_set = old_res_set.FindUpdatedResources(new_res_set); + RAY_LOG(DEBUG) << "[ResourceCreateUpdated] The difference in the resource map is " + << difference_set.ToString(); + + SchedulingResources &cluster_schedres = cluster_resource_map_[client_id]; + + // Update local_available_resources_ and SchedulingResources + for (const auto &resource_pair : difference_set.GetResourceMap()) { + const std::string &resource_label = resource_pair.first; + const double &new_resource_capacity = resource_pair.second; + + cluster_schedres.UpdateResource(resource_label, new_resource_capacity); + if (client_id == local_client_id) { + local_available_resources_.AddOrUpdateResource(resource_label, + new_resource_capacity); + } + } + RAY_LOG(DEBUG) << "[ResourceCreateUpdated] Updated cluster_resource_map."; + + if (client_id == local_client_id) { + // The resource update is on the local node, check if we can reschedule tasks. + TryLocalInfeasibleTaskScheduling(); + } + return; +} + +void NodeManager::ResourceDeleted(const ClientTableDataT &client_data) { + const ClientID client_id = ClientID::from_binary(client_data.client_id); + const ClientID &local_client_id = gcs_client_->client_table().GetLocalClientId(); + + ResourceSet new_res_set(client_data.resources_total_label, + client_data.resources_total_capacity); + RAY_LOG(DEBUG) << "[ResourceDeleted] received callback from client id " << client_id + << " with new resources: " << new_res_set.ToString() + << ". 
Updating resource map."; + + const ResourceSet &old_res_set = cluster_resource_map_[client_id].GetTotalResources(); + ResourceSet deleted_set = old_res_set.FindDeletedResources(new_res_set); + RAY_LOG(DEBUG) << "[ResourceDeleted] The difference in the resource map is " + << deleted_set.ToString(); + + SchedulingResources &cluster_schedres = cluster_resource_map_[client_id]; + + // Update local_available_resources_ and SchedulingResources + for (const auto &resource_pair : deleted_set.GetResourceMap()) { + const std::string &resource_label = resource_pair.first; + + cluster_schedres.DeleteResource(resource_label); + if (client_id == local_client_id) { + local_available_resources_.DeleteResource(resource_label); + } + } + RAY_LOG(DEBUG) << "[ResourceDeleted] Updated cluster_resource_map."; + return; +} + +void NodeManager::TryLocalInfeasibleTaskScheduling() { + RAY_LOG(DEBUG) << "[LocalResourceUpdateRescheduler] The resource update is on the " + "local node, check if we can reschedule tasks"; + const ClientID &local_client_id = gcs_client_->client_table().GetLocalClientId(); + SchedulingResources &new_local_resources = cluster_resource_map_[local_client_id]; + + // SpillOver locally to figure out which infeasible tasks can be placed now + std::vector decision = scheduling_policy_.SpillOver(new_local_resources); + + std::unordered_set local_task_ids(decision.begin(), decision.end()); + + // Transition locally placed tasks to waiting or ready for dispatch. + if (local_task_ids.size() > 0) { + std::vector tasks = local_queues_.RemoveTasks(local_task_ids); + for (const auto &t : tasks) { + EnqueuePlaceableTask(t); + } + } +} + void NodeManager::HeartbeatAdded(const ClientID &client_id, const HeartbeatTableDataT &heartbeat_data) { // Locate the client id in remote client table and update available resources based on @@ -672,7 +774,10 @@ void NodeManager::DispatchTasks( } } } - local_queues_.RemoveTasks(removed_task_ids); + // Move the ASSIGNED task to the SWAP queue so that we remember that we have + // it queued locally. Once the GetTaskReply has been sent, the task will get + // re-queued, depending on whether the message succeeded or not. + local_queues_.MoveTasks(removed_task_ids, TaskState::READY, TaskState::SWAP); } void NodeManager::ProcessClientMessage( @@ -718,6 +823,9 @@ void NodeManager::ProcessClientMessage( case protocol::MessageType::SubmitTask: { ProcessSubmitTaskMessage(message_data); } break; + case protocol::MessageType::SetResourceRequest: { + ProcessSetResourceRequest(client, message_data); + } break; case protocol::MessageType::FetchOrReconstruct: { ProcessFetchOrReconstructMessage(client, message_data); } break; @@ -744,7 +852,7 @@ void NodeManager::ProcessClientMessage( // Clean up their creating tasks from GCS. std::vector creating_task_ids; for (const auto &object_id : object_ids) { - creating_task_ids.push_back(ComputeTaskId(object_id)); + creating_task_ids.push_back(object_id.task_id()); } gcs_client_->raylet_task_table().Delete(DriverID::nil(), creating_task_ids); } @@ -779,11 +887,12 @@ void NodeManager::ProcessRegisterClientRequestMessage( // message is actually the ID of the driver task, while client_id represents the // real driver ID, which can associate all the tasks/actors for a given driver, // which is set to the worker ID. 
- const DriverID driver_task_id = from_flatbuf(*message->driver_id()); - worker->AssignTaskId(TaskID(driver_task_id)); + const DriverID driver_id = from_flatbuf(*message->driver_id()); + TaskID driver_task_id = TaskID::GetDriverTaskID(driver_id); + worker->AssignTaskId(driver_task_id); worker->AssignDriverId(from_flatbuf(*message->client_id())); worker_pool_.RegisterDriver(std::move(worker)); - local_queues_.AddDriverTaskId(TaskID(driver_task_id)); + local_queues_.AddDriverTaskId(driver_task_id); } } @@ -931,12 +1040,14 @@ void NodeManager::ProcessDisconnectClientMessage( // Return the resources that were being used by this worker. auto const &task_resources = worker->GetTaskResourceIds(); - local_available_resources_.Release(task_resources); + local_available_resources_.ReleaseConstrained( + task_resources, cluster_resource_map_[client_id].GetTotalResources()); cluster_resource_map_[client_id].Release(task_resources.ToResourceSet()); worker->ResetTaskResourceIds(); auto const &lifetime_resources = worker->GetLifetimeResourceIds(); - local_available_resources_.Release(lifetime_resources); + local_available_resources_.ReleaseConstrained( + lifetime_resources, cluster_resource_map_[client_id].GetTotalResources()); cluster_resource_map_[client_id].Release(lifetime_resources.ToResourceSet()); worker->ResetLifetimeResourceIds(); @@ -1170,6 +1281,59 @@ void NodeManager::ProcessNodeManagerMessage(TcpClientConnection &node_manager_cl node_manager_client.ProcessMessages(); } +void NodeManager::ProcessSetResourceRequest( + const std::shared_ptr &client, const uint8_t *message_data) { + // Read the SetResource message + auto message = flatbuffers::GetRoot(message_data); + + auto const &resource_name = string_from_flatbuf(*message->resource_name()); + double const &capacity = message->capacity(); + bool is_deletion = capacity <= 0; + + ClientID client_id = from_flatbuf(*message->client_id()); + + // If the python arg was null, set client_id to the local client + if (client_id.is_nil()) { + client_id = gcs_client_->client_table().GetLocalClientId(); + } + + if (is_deletion && + cluster_resource_map_[client_id].GetTotalResources().GetResourceMap().count( + resource_name) == 0) { + // Resource does not exist in the cluster resource map, thus nothing to delete. + // Return.. + RAY_LOG(INFO) << "[ProcessDeleteResourceRequest] Trying to delete resource " + << resource_name << ", but it does not exist. Doing nothing.."; + return; + } + + // Add the new resource to a skeleton ClientTableDataT object + ClientTableDataT data; + gcs_client_->client_table().GetClient(client_id, data); + // Replace the resource vectors with the resource deltas from the message. + // RES_CREATEUPDATE and RES_DELETE entries in the ClientTable track changes (deltas) in + // the resources + data.resources_total_label = std::vector{resource_name}; + data.resources_total_capacity = std::vector{capacity}; + // Set the correct flag for entry_type + if (is_deletion) { + data.entry_type = EntryType::RES_DELETE; + } else { + data.entry_type = EntryType::RES_CREATEUPDATE; + } + + // Submit to the client table. This calls the ResourceCreateUpdated callback, which + // updates cluster_resource_map_. 
+ std::shared_ptr worker = worker_pool_.GetRegisteredWorker(client); + if (not worker) { + worker = worker_pool_.GetRegisteredDriver(client); + } + auto data_shared_ptr = std::make_shared(data); + auto client_table = gcs_client_->client_table(); + RAY_CHECK_OK(gcs_client_->client_table().Append( + DriverID::nil(), client_table.client_log_key_, data_shared_ptr, nullptr)); +} + void NodeManager::ScheduleTasks( std::unordered_map &resource_map) { const ClientID &local_client_id = gcs_client_->client_table().GetLocalClientId(); @@ -1665,11 +1829,15 @@ bool NodeManager::AssignTask(const Task &task) { auto message = protocol::CreateGetTaskReply(fbb, spec.ToFlatbuffer(fbb), fbb.CreateVector(resource_id_set_flatbuf)); fbb.Finish(message); - // Give the callback a copy of the task so it can modify it. - Task assigned_task(task); + const auto &task_id = spec.TaskId(); worker->Connection()->WriteMessageAsync( static_cast(protocol::MessageType::ExecuteTask), fbb.GetSize(), - fbb.GetBufferPointer(), [this, worker, assigned_task](ray::Status status) mutable { + fbb.GetBufferPointer(), [this, worker, task_id](ray::Status status) { + // Remove the ASSIGNED task from the SWAP queue. + TaskState state; + auto assigned_task = local_queues_.RemoveTask(task_id, &state); + RAY_CHECK(state == TaskState::SWAP); + if (status.ok()) { auto spec = assigned_task.GetTaskSpecification(); // We successfully assigned the task to the worker. @@ -1761,7 +1929,9 @@ void NodeManager::FinishAssignedTask(Worker &worker) { // Release task's resources. The worker's lifetime resources are still held. auto const &task_resources = worker.GetTaskResourceIds(); - local_available_resources_.Release(task_resources); + const ClientID &client_id = gcs_client_->client_table().GetLocalClientId(); + local_available_resources_.ReleaseConstrained( + task_resources, cluster_resource_map_[client_id].GetTotalResources()); cluster_resource_map_[gcs_client_->client_table().GetLocalClientId()].Release( task_resources.ToResourceSet()); worker.ResetTaskResourceIds(); @@ -2050,9 +2220,9 @@ void NodeManager::ForwardTaskOrResubmit(const Task &task, /// TODO(rkn): Should we check that the node manager is remote and not local? /// TODO(rkn): Should we check if the remote node manager is known to be dead? // Attempt to forward the task. - ForwardTask(task, node_manager_id, [this, task, node_manager_id](ray::Status error) { + ForwardTask(task, node_manager_id, [this, node_manager_id](ray::Status error, + const Task &task) { const TaskID task_id = task.GetTaskSpecification().TaskId(); - RAY_LOG(INFO) << "Failed to forward task " << task_id << " to node manager " << node_manager_id; @@ -2074,14 +2244,22 @@ void NodeManager::ForwardTaskOrResubmit(const Task &task, RayConfig::instance().node_manager_forward_task_retry_timeout_milliseconds()); retry_timer->expires_from_now(retry_duration); retry_timer->async_wait( - [this, task, task_id, retry_timer](const boost::system::error_code &error) { + [this, task_id, retry_timer](const boost::system::error_code &error) { // Timer killing will receive the boost::asio::error::operation_aborted, // we only handle the timeout event. RAY_CHECK(!error); RAY_LOG(INFO) << "Resubmitting task " << task_id << " because ForwardTask failed."; + // Remove the RESUBMITTED task from the SWAP queue. + TaskState state; + const auto task = local_queues_.RemoveTask(task_id, &state); + RAY_CHECK(state == TaskState::SWAP); + // Submit the task again. 
SubmitTask(task, Lineage()); }); + // Temporarily move the RESUBMITTED task to the SWAP queue while the + // timer is active. + local_queues_.QueueTasks({task}, TaskState::SWAP); // Remove the task from the lineage cache. The task will get added back // once it is resubmitted. lineage_cache_.RemoveWaitingTask(task_id); @@ -2094,8 +2272,9 @@ void NodeManager::ForwardTaskOrResubmit(const Task &task, }); } -void NodeManager::ForwardTask(const Task &task, const ClientID &node_id, - const std::function &on_error) { +void NodeManager::ForwardTask( + const Task &task, const ClientID &node_id, + const std::function &on_error) { const auto &spec = task.GetTaskSpecification(); auto task_id = spec.TaskId(); @@ -2129,16 +2308,25 @@ void NodeManager::ForwardTask(const Task &task, const ClientID &node_id, if (it == remote_server_connections_.end()) { // TODO(atumanov): caller must handle failure to ensure tasks are not lost. RAY_LOG(INFO) << "No NodeManager connection found for GCS client id " << node_id; - on_error(ray::Status::IOError("NodeManager connection not found")); + on_error(ray::Status::IOError("NodeManager connection not found"), task); return; } - auto &server_conn = it->second; + // Move the FORWARDING task to the SWAP queue so that we remember that we + // have it queued locally. Once the ForwardTaskRequest has been sent, the + // task will get re-queued, depending on whether the message succeeded or + // not. + local_queues_.QueueTasks({task}, TaskState::SWAP); server_conn->WriteMessageAsync( static_cast(protocol::MessageType::ForwardTaskRequest), fbb.GetSize(), - fbb.GetBufferPointer(), - [this, on_error, task_id, node_id, spec](ray::Status status) { + fbb.GetBufferPointer(), [this, on_error, task_id, node_id](ray::Status status) { + // Remove the FORWARDING task from the SWAP queue. + TaskState state; + const auto task = local_queues_.RemoveTask(task_id, &state); + RAY_CHECK(state == TaskState::SWAP); + if (status.ok()) { + const auto &spec = task.GetTaskSpecification(); // If we were able to forward the task, remove the forwarded task from the // lineage cache since the receiving node is now responsible for writing // the task to the GCS. @@ -2173,7 +2361,7 @@ void NodeManager::ForwardTask(const Task &task, const ClientID &node_id, } } } else { - on_error(status); + on_error(status, task); } }); } diff --git a/src/ray/raylet/node_manager.h b/src/ray/raylet/node_manager.h index edd456dbecce..576ffbc23f72 100644 --- a/src/ray/raylet/node_manager.h +++ b/src/ray/raylet/node_manager.h @@ -120,6 +120,21 @@ class NodeManager { /// \return Void. void ClientRemoved(const ClientTableDataT &client_data); + /// Handler for the addition or update of a resource in the GCS. + /// \param client_data Data for the client whose resource was created or updated. + /// \return Void. + void ResourceCreateUpdated(const ClientTableDataT &client_data); + + /// Handler for the deletion of a resource in the GCS. + /// \param client_data Data for the client whose resource was deleted. + /// \return Void. + void ResourceDeleted(const ClientTableDataT &client_data); + + /// Evaluates the local infeasible queue to check if any tasks can be scheduled. + /// This is called whenever there's an update to the resources on the local client. + /// \return Void. + void TryLocalInfeasibleTaskScheduling(); + /// Send heartbeats to the GCS. void Heartbeat(); @@ -231,8 +246,9 @@ class NodeManager { /// \param task The task to forward. /// \param node_id The ID of the node to forward the task to. /// \param on_error Callback on run on non-ok status.
- void ForwardTask(const Task &task, const ClientID &node_id, - const std::function &on_error); + void ForwardTask( + const Task &task, const ClientID &node_id, + const std::function &on_error); /// Dispatch locally scheduled tasks. This attempts the transition from "scheduled" to /// "running" task state. @@ -413,6 +429,13 @@ class NodeManager { /// \param task The task that just finished. void UpdateActorFrontier(const Task &task); + /// Process client message of SetResourceRequest + /// \param client The client that sent the message. + /// \param message_data A pointer to the message data. + /// \return Void. + void ProcessSetResourceRequest(const std::shared_ptr &client, + const uint8_t *message_data); + /// Handle the case where an actor is disconnected, determine whether this /// actor needs to be reconstructed and then update actor table. /// This function needs to be called either when actor process dies or when diff --git a/src/ray/raylet/raylet_client.cc b/src/ray/raylet/raylet_client.cc index 09e9b5fed5e2..0f488089e6d0 100644 --- a/src/ray/raylet/raylet_client.cc +++ b/src/ray/raylet/raylet_client.cc @@ -386,3 +386,13 @@ ray::Status RayletClient::NotifyActorResumedFromCheckpoint( return conn_->WriteMessage(MessageType::NotifyActorResumedFromCheckpoint, &fbb); } + +ray::Status RayletClient::SetResource(const std::string &resource_name, + const double capacity, + const ray::ClientID &client_Id) { + flatbuffers::FlatBufferBuilder fbb; + auto message = ray::protocol::CreateSetResourceRequest( + fbb, fbb.CreateString(resource_name), capacity, to_flatbuf(fbb, client_Id)); + fbb.Finish(message); + return conn_->WriteMessage(MessageType::SetResourceRequest, &fbb); +} diff --git a/src/ray/raylet/raylet_client.h b/src/ray/raylet/raylet_client.h index ff66ff4621b0..0bdd076b5577 100644 --- a/src/ray/raylet/raylet_client.h +++ b/src/ray/raylet/raylet_client.h @@ -165,6 +165,14 @@ class RayletClient { ray::Status NotifyActorResumedFromCheckpoint(const ActorID &actor_id, const ActorCheckpointID &checkpoint_id); + /// Sets a resource with the specified capacity and client id + /// \param resource_name Name of the resource to be set + /// \param capacity Capacity of the resource + /// \param client_Id ClientID where the resource is to be set + /// \return ray::Status + ray::Status SetResource(const std::string &resource_name, const double capacity, + const ray::ClientID &client_Id); + Language GetLanguage() const { return language_; } ClientID GetClientID() const { return client_id_; } diff --git a/src/ray/raylet/reconstruction_policy.cc b/src/ray/raylet/reconstruction_policy.cc index 4274ff5a2018..d1a648a34ce4 100644 --- a/src/ray/raylet/reconstruction_policy.cc +++ b/src/ray/raylet/reconstruction_policy.cc @@ -171,7 +171,7 @@ void ReconstructionPolicy::HandleTaskLeaseNotification(const TaskID &task_id, } void ReconstructionPolicy::ListenAndMaybeReconstruct(const ObjectID &object_id) { - TaskID task_id = ComputeTaskId(object_id); + TaskID task_id = object_id.task_id(); auto it = listening_tasks_.find(task_id); // Add this object to the list of objects created by the same task. if (it == listening_tasks_.end()) { @@ -185,7 +185,7 @@ void ReconstructionPolicy::ListenAndMaybeReconstruct(const ObjectID &object_id) } void ReconstructionPolicy::Cancel(const ObjectID &object_id) { - TaskID task_id = ComputeTaskId(object_id); + TaskID task_id = object_id.task_id(); auto it = listening_tasks_.find(task_id); if (it == listening_tasks_.end()) { // We already stopped listening for this task. 
diff --git a/src/ray/raylet/reconstruction_policy_test.cc b/src/ray/raylet/reconstruction_policy_test.cc index d9fb92388aa6..7f8887b15372 100644 --- a/src/ray/raylet/reconstruction_policy_test.cc +++ b/src/ray/raylet/reconstruction_policy_test.cc @@ -224,8 +224,7 @@ class ReconstructionPolicyTest : public ::testing::Test { TEST_F(ReconstructionPolicyTest, TestReconstructionSimple) { TaskID task_id = TaskID::from_random(); - task_id = FinishTaskId(task_id); - ObjectID object_id = ComputeReturnId(task_id, 1); + ObjectID object_id = ObjectID::for_task_return(task_id, 1); // Listen for an object. reconstruction_policy_->ListenAndMaybeReconstruct(object_id); @@ -243,8 +242,7 @@ TEST_F(ReconstructionPolicyTest, TestReconstructionSimple) { TEST_F(ReconstructionPolicyTest, TestReconstructionEvicted) { TaskID task_id = TaskID::from_random(); - task_id = FinishTaskId(task_id); - ObjectID object_id = ComputeReturnId(task_id, 1); + ObjectID object_id = ObjectID::for_task_return(task_id, 1); mock_object_directory_->SetObjectLocations(object_id, {ClientID::from_random()}); // Listen for both objects. @@ -267,8 +265,7 @@ TEST_F(ReconstructionPolicyTest, TestReconstructionEvicted) { TEST_F(ReconstructionPolicyTest, TestReconstructionObjectLost) { TaskID task_id = TaskID::from_random(); - task_id = FinishTaskId(task_id); - ObjectID object_id = ComputeReturnId(task_id, 1); + ObjectID object_id = ObjectID::for_task_return(task_id, 1); ClientID client_id = ClientID::from_random(); mock_object_directory_->SetObjectLocations(object_id, {client_id}); @@ -292,9 +289,8 @@ TEST_F(ReconstructionPolicyTest, TestReconstructionObjectLost) { TEST_F(ReconstructionPolicyTest, TestDuplicateReconstruction) { // Create two object IDs produced by the same task. TaskID task_id = TaskID::from_random(); - task_id = FinishTaskId(task_id); - ObjectID object_id1 = ComputeReturnId(task_id, 1); - ObjectID object_id2 = ComputeReturnId(task_id, 2); + ObjectID object_id1 = ObjectID::for_task_return(task_id, 1); + ObjectID object_id2 = ObjectID::for_task_return(task_id, 2); // Listen for both objects. reconstruction_policy_->ListenAndMaybeReconstruct(object_id1); @@ -313,8 +309,7 @@ TEST_F(ReconstructionPolicyTest, TestDuplicateReconstruction) { TEST_F(ReconstructionPolicyTest, TestReconstructionSuppressed) { TaskID task_id = TaskID::from_random(); - task_id = FinishTaskId(task_id); - ObjectID object_id = ComputeReturnId(task_id, 1); + ObjectID object_id = ObjectID::for_task_return(task_id, 1); // Run the test for much longer than the reconstruction timeout. int64_t test_period = 2 * reconstruction_timeout_ms_; @@ -340,8 +335,7 @@ TEST_F(ReconstructionPolicyTest, TestReconstructionSuppressed) { TEST_F(ReconstructionPolicyTest, TestReconstructionContinuallySuppressed) { TaskID task_id = TaskID::from_random(); - task_id = FinishTaskId(task_id); - ObjectID object_id = ComputeReturnId(task_id, 1); + ObjectID object_id = ObjectID::for_task_return(task_id, 1); // Listen for an object. reconstruction_policy_->ListenAndMaybeReconstruct(object_id); @@ -368,8 +362,7 @@ TEST_F(ReconstructionPolicyTest, TestReconstructionContinuallySuppressed) { TEST_F(ReconstructionPolicyTest, TestReconstructionCanceled) { TaskID task_id = TaskID::from_random(); - task_id = FinishTaskId(task_id); - ObjectID object_id = ComputeReturnId(task_id, 1); + ObjectID object_id = ObjectID::for_task_return(task_id, 1); // Listen for an object. 
reconstruction_policy_->ListenAndMaybeReconstruct(object_id); @@ -395,8 +388,7 @@ TEST_F(ReconstructionPolicyTest, TestReconstructionCanceled) { TEST_F(ReconstructionPolicyTest, TestSimultaneousReconstructionSuppressed) { TaskID task_id = TaskID::from_random(); - task_id = FinishTaskId(task_id); - ObjectID object_id = ComputeReturnId(task_id, 1); + ObjectID object_id = ObjectID::for_task_return(task_id, 1); // Log a reconstruction attempt to simulate a different node attempting the // reconstruction first. This should suppress this node's first attempt at diff --git a/src/ray/raylet/scheduling_queue.cc b/src/ray/raylet/scheduling_queue.cc index 986ede199d38..29af345b8391 100644 --- a/src/ray/raylet/scheduling_queue.cc +++ b/src/ray/raylet/scheduling_queue.cc @@ -9,7 +9,8 @@ namespace { static constexpr const char *task_state_strings[] = { "placeable", "waiting", "ready", - "running", "infeasible", "waiting_for_actor_creation"}; + "running", "infeasible", "waiting for actor creation", + "swap"}; static_assert(sizeof(task_state_strings) / sizeof(const char *) == static_cast(ray::raylet::TaskState::kNumTaskQueues), "Must specify a TaskState name for every task queue"); @@ -172,6 +173,9 @@ void SchedulingQueue::FilterState(std::unordered_set &task_ids, case TaskState::INFEASIBLE: FilterStateFromQueue(task_ids, TaskState::INFEASIBLE); break; + case TaskState::SWAP: + FilterStateFromQueue(task_ids, TaskState::SWAP); + break; case TaskState::BLOCKED: { const auto blocked_ids = GetBlockedTaskIds(); for (auto it = task_ids.begin(); it != task_ids.end();) { @@ -230,7 +234,7 @@ std::vector SchedulingQueue::RemoveTasks(std::unordered_set &task_ // Try to find the tasks to remove from the queues. for (const auto &task_state : { TaskState::PLACEABLE, TaskState::WAITING, TaskState::READY, TaskState::RUNNING, - TaskState::INFEASIBLE, TaskState::WAITING_FOR_ACTOR_CREATION, + TaskState::INFEASIBLE, TaskState::WAITING_FOR_ACTOR_CREATION, TaskState::SWAP, }) { RemoveTasksFromQueue(task_state, task_ids, &removed_tasks); } @@ -245,7 +249,7 @@ Task SchedulingQueue::RemoveTask(const TaskID &task_id, TaskState *removed_task_ // Try to find the task to remove in the queues. for (const auto &task_state : { TaskState::PLACEABLE, TaskState::WAITING, TaskState::READY, TaskState::RUNNING, - TaskState::INFEASIBLE, TaskState::WAITING_FOR_ACTOR_CREATION, + TaskState::INFEASIBLE, TaskState::WAITING_FOR_ACTOR_CREATION, TaskState::SWAP, }) { RemoveTasksFromQueue(task_state, task_id_set, &removed_tasks); if (task_id_set.empty()) { @@ -260,7 +264,7 @@ Task SchedulingQueue::RemoveTask(const TaskID &task_id, TaskState *removed_task_ } // Make sure we got the removed task. 
- RAY_CHECK(removed_tasks.size() == 1); + RAY_CHECK(removed_tasks.size() == 1) << task_id; const auto &task = removed_tasks.front(); RAY_CHECK(task.GetTaskSpecification().TaskId() == task_id); return task; @@ -287,6 +291,9 @@ void SchedulingQueue::MoveTasks(std::unordered_set &task_ids, TaskState case TaskState::INFEASIBLE: RemoveTasksFromQueue(TaskState::INFEASIBLE, task_ids, &removed_tasks); break; + case TaskState::SWAP: + RemoveTasksFromQueue(TaskState::SWAP, task_ids, &removed_tasks); + break; default: RAY_LOG(FATAL) << "Attempting to move tasks from unrecognized state " << static_cast::type>(src_state); @@ -312,6 +319,9 @@ void SchedulingQueue::MoveTasks(std::unordered_set &task_ids, TaskState case TaskState::INFEASIBLE: QueueTasks(removed_tasks, TaskState::INFEASIBLE); break; + case TaskState::SWAP: + QueueTasks(removed_tasks, TaskState::SWAP); + break; default: RAY_LOG(FATAL) << "Attempting to move tasks to unrecognized state " << static_cast::type>(dst_state); @@ -348,8 +358,16 @@ std::unordered_set SchedulingQueue::GetTaskIdsForDriver( std::unordered_set SchedulingQueue::GetTaskIdsForActor( const ActorID &actor_id) const { std::unordered_set task_ids; + int swap = static_cast(TaskState::SWAP); + int i = 0; for (const auto &task_queue : task_queues_) { - GetActorTasksFromQueue(*task_queue, actor_id, task_ids); + // This is a hack to make sure that we don't remove tasks from the SWAP + // queue, since these are always guaranteed to be removed and eventually + // resubmitted if necessary by the node manager. + if (i != swap) { + GetActorTasksFromQueue(*task_queue, actor_id, task_ids); + } + i++; } return task_ids; } @@ -385,10 +403,8 @@ const std::unordered_set &SchedulingQueue::GetDriverTaskIds() const { std::string SchedulingQueue::DebugString() const { std::stringstream result; result << "SchedulingQueue:"; - for (const auto &task_state : { - TaskState::PLACEABLE, TaskState::WAITING, TaskState::READY, TaskState::RUNNING, - TaskState::INFEASIBLE, TaskState::WAITING_FOR_ACTOR_CREATION, - }) { + for (size_t i = 0; i < static_cast(ray::raylet::TaskState::kNumTaskQueues); i++) { + TaskState task_state = static_cast(i); result << "\n- num " << GetTaskStateString(task_state) << " tasks: " << GetTaskQueue(task_state)->GetTasks().size(); } @@ -397,10 +413,8 @@ std::string SchedulingQueue::DebugString() const { } void SchedulingQueue::RecordMetrics() const { - for (const auto &task_state : { - TaskState::PLACEABLE, TaskState::WAITING, TaskState::READY, TaskState::RUNNING, - TaskState::INFEASIBLE, TaskState::WAITING_FOR_ACTOR_CREATION, - }) { + for (size_t i = 0; i < static_cast(ray::raylet::TaskState::kNumTaskQueues); i++) { + TaskState task_state = static_cast(i); stats::SchedulingQueueStats().Record( static_cast(GetTaskQueue(task_state)->GetTasks().size()), {{stats::ValueTypeKey, diff --git a/src/ray/raylet/scheduling_queue.h b/src/ray/raylet/scheduling_queue.h index 3f7bab1233cb..4fd07e5ca606 100644 --- a/src/ray/raylet/scheduling_queue.h +++ b/src/ray/raylet/scheduling_queue.h @@ -33,6 +33,13 @@ enum class TaskState { // The task is an actor method and is waiting to learn where the actor was // created. WAITING_FOR_ACTOR_CREATION, + // Swap queue for tasks that are in between states. This can happen when a + // task is removed from one queue, and an async callback is responsible for + // re-queuing the task. For example, a READY task that has just been assigned + // to a worker will get moved to the SWAP queue while waiting for a response + // from the worker. 
If the worker accepts the task, the task will be added to + // the RUNNING queue, else it will be returned to READY. + SWAP, // The number of task queues. All states that precede this enum must have an // associated TaskQueue in SchedulingQueue. All states that succeed // this enum do not have an associated TaskQueue, since the tasks @@ -144,7 +151,7 @@ class SchedulingQueue { for (const auto &task_state : { TaskState::PLACEABLE, TaskState::WAITING, TaskState::READY, TaskState::RUNNING, TaskState::INFEASIBLE, - TaskState::WAITING_FOR_ACTOR_CREATION, + TaskState::WAITING_FOR_ACTOR_CREATION, TaskState::SWAP, }) { if (task_state == TaskState::READY) { task_queues_[static_cast(task_state)] = ready_queue_; diff --git a/src/ray/raylet/task_dependency_manager.cc b/src/ray/raylet/task_dependency_manager.cc index 4fbebb8df79f..dc24c95d46e4 100644 --- a/src/ray/raylet/task_dependency_manager.cc +++ b/src/ray/raylet/task_dependency_manager.cc @@ -24,7 +24,7 @@ bool TaskDependencyManager::CheckObjectLocal(const ObjectID &object_id) const { } bool TaskDependencyManager::CheckObjectRequired(const ObjectID &object_id) const { - const TaskID task_id = ComputeTaskId(object_id); + const TaskID task_id = object_id.task_id(); auto task_entry = required_tasks_.find(task_id); // If there are no subscribed tasks that are dependent on the object, then do // nothing. @@ -82,7 +82,7 @@ std::vector TaskDependencyManager::HandleObjectLocal( // Find any tasks that are dependent on the newly available object. std::vector ready_task_ids; - auto creating_task_entry = required_tasks_.find(ComputeTaskId(object_id)); + auto creating_task_entry = required_tasks_.find(object_id.task_id()); if (creating_task_entry != required_tasks_.end()) { auto object_entry = creating_task_entry->second.find(object_id); if (object_entry != creating_task_entry->second.end()) { @@ -113,7 +113,7 @@ std::vector TaskDependencyManager::HandleObjectMissing( // Find any tasks that are dependent on the missing object. std::vector waiting_task_ids; - TaskID creating_task_id = ComputeTaskId(object_id); + TaskID creating_task_id = object_id.task_id(); auto creating_task_entry = required_tasks_.find(creating_task_id); if (creating_task_entry != required_tasks_.end()) { auto object_entry = creating_task_entry->second.find(object_id); @@ -149,7 +149,7 @@ bool TaskDependencyManager::SubscribeDependencies( auto inserted = task_entry.object_dependencies.insert(object_id); if (inserted.second) { // Get the ID of the task that creates the dependency. - TaskID creating_task_id = ComputeTaskId(object_id); + TaskID creating_task_id = object_id.task_id(); // Determine whether the dependency can be fulfilled by the local node. if (local_objects_.count(object_id) == 0) { // The object is not local. @@ -186,7 +186,7 @@ bool TaskDependencyManager::UnsubscribeDependencies(const TaskID &task_id) { // Remove the task from the list of tasks that are dependent on this // object. // Get the ID of the task that creates the dependency. - TaskID creating_task_id = ComputeTaskId(object_id); + TaskID creating_task_id = object_id.task_id(); auto creating_task_entry = required_tasks_.find(creating_task_id); std::vector &dependent_tasks = creating_task_entry->second[object_id]; auto it = std::find(dependent_tasks.begin(), dependent_tasks.end(), task_id); @@ -324,7 +324,7 @@ void TaskDependencyManager::RemoveTasksAndRelatedObjects( // Cancel all of the objects that were required by the removed tasks. 
for (const auto &object_id : required_objects) { - TaskID creating_task_id = ComputeTaskId(object_id); + TaskID creating_task_id = object_id.task_id(); required_tasks_.erase(creating_task_id); HandleRemoteDependencyCanceled(object_id); } diff --git a/src/ray/raylet/task_dependency_manager_test.cc b/src/ray/raylet/task_dependency_manager_test.cc index 5126d82555af..62bbf17069d5 100644 --- a/src/ray/raylet/task_dependency_manager_test.cc +++ b/src/ray/raylet/task_dependency_manager_test.cc @@ -266,7 +266,7 @@ TEST_F(TaskDependencyManagerTest, TestTaskChain) { TEST_F(TaskDependencyManagerTest, TestDependentPut) { // Create a task with 3 arguments. auto task1 = ExampleTask({}, 0); - ObjectID put_id = ComputePutId(task1.GetTaskSpecification().TaskId(), 1); + ObjectID put_id = ObjectID::for_put(task1.GetTaskSpecification().TaskId(), 1); auto task2 = ExampleTask({put_id}, 0); // No objects have been registered in the task dependency manager, so the put diff --git a/src/ray/raylet/task_spec.cc b/src/ray/raylet/task_spec.cc index 5f301c47c1c3..17a8b185fc78 100644 --- a/src/ray/raylet/task_spec.cc +++ b/src/ray/raylet/task_spec.cc @@ -92,20 +92,14 @@ TaskSpecification::TaskSpecification( arguments.push_back(argument->ToFlatbuffer(fbb)); } - // Generate return ids. - std::vector returns; - for (int64_t i = 1; i < num_returns + 1; ++i) { - returns.push_back(ComputeReturnId(task_id, i)); - } - // Serialize the TaskSpecification. auto spec = CreateTaskInfo( fbb, to_flatbuf(fbb, driver_id), to_flatbuf(fbb, task_id), to_flatbuf(fbb, parent_task_id), parent_counter, to_flatbuf(fbb, actor_creation_id), to_flatbuf(fbb, actor_creation_dummy_object_id), max_actor_reconstructions, to_flatbuf(fbb, actor_id), to_flatbuf(fbb, actor_handle_id), actor_counter, - ids_to_flatbuf(fbb, new_actor_handles), fbb.CreateVector(arguments), - ids_to_flatbuf(fbb, returns), map_to_flatbuf(fbb, required_resources), + ids_to_flatbuf(fbb, new_actor_handles), fbb.CreateVector(arguments), num_returns, + map_to_flatbuf(fbb, required_resources), map_to_flatbuf(fbb, required_placement_resources), language, string_vec_to_flatbuf(fbb, function_descriptor)); fbb.Finish(spec); @@ -167,12 +161,12 @@ int64_t TaskSpecification::NumArgs() const { int64_t TaskSpecification::NumReturns() const { auto message = flatbuffers::GetRoot(spec_.data()); - return (message->returns()->size() / kUniqueIDSize); + return message->num_returns(); } ObjectID TaskSpecification::ReturnId(int64_t return_index) const { auto message = flatbuffers::GetRoot(spec_.data()); - return ids_from_flatbuf(*message->returns())[return_index]; + return ObjectID::for_task_return(TaskId(), return_index + 1); } bool TaskSpecification::ArgByRef(int64_t arg_index) const { diff --git a/src/ray/raylet/task_test.cc b/src/ray/raylet/task_test.cc index 9f3545bdf638..6d0cfa37017a 100644 --- a/src/ray/raylet/task_test.cc +++ b/src/ray/raylet/task_test.cc @@ -1,5 +1,6 @@ #include "gtest/gtest.h" +#include "ray/common/common_protocol.h" #include "ray/raylet/task_spec.h" namespace ray { @@ -9,21 +10,21 @@ namespace raylet { void TestTaskReturnId(const TaskID &task_id, int64_t return_index) { // Round trip test for computing the object ID for a task's return value, // then computing the task ID that created the object. 
- ObjectID return_id = ComputeReturnId(task_id, return_index); - ASSERT_EQ(ComputeTaskId(return_id), task_id); - ASSERT_EQ(ComputeObjectIndex(return_id), return_index); + ObjectID return_id = ObjectID::for_task_return(task_id, return_index); + ASSERT_EQ(return_id.task_id(), task_id); + ASSERT_EQ(return_id.object_index(), return_index); } void TestTaskPutId(const TaskID &task_id, int64_t put_index) { // Round trip test for computing the object ID for a task's put value, then // computing the task ID that created the object. - ObjectID put_id = ComputePutId(task_id, put_index); - ASSERT_EQ(ComputeTaskId(put_id), task_id); - ASSERT_EQ(ComputeObjectIndex(put_id), -1 * put_index); + ObjectID put_id = ObjectID::for_put(task_id, put_index); + ASSERT_EQ(put_id.task_id(), task_id); + ASSERT_EQ(put_id.object_index(), -1 * put_index); } TEST(TaskSpecTest, TestTaskReturnIds) { - TaskID task_id = FinishTaskId(TaskID::from_random()); + TaskID task_id = TaskID::from_random(); // Check that we can compute between a task ID and the object IDs of its // return values and puts. @@ -35,6 +36,68 @@ TEST(TaskSpecTest, TestTaskReturnIds) { TestTaskPutId(task_id, kMaxTaskPuts); } +TEST(IdPropertyTest, TestIdProperty) { + TaskID task_id = TaskID::from_random(); + ASSERT_EQ(task_id, TaskID::from_binary(task_id.binary())); + ObjectID object_id = ObjectID::from_random(); + ASSERT_EQ(object_id, ObjectID::from_binary(object_id.binary())); + + ASSERT_TRUE(TaskID().is_nil()); + ASSERT_TRUE(TaskID::nil().is_nil()); + ASSERT_TRUE(ObjectID().is_nil()); + ASSERT_TRUE(ObjectID::nil().is_nil()); +} + +TEST(TaskSpecTest, TaskInfoSize) { + std::vector references = {ObjectID::from_random(), ObjectID::from_random()}; + auto arguments_1 = std::make_shared(references); + std::string one_arg("This is a value argument."); + auto arguments_2 = std::make_shared( + reinterpret_cast(one_arg.c_str()), one_arg.size()); + std::vector> task_arguments({arguments_1, arguments_2}); + auto task_id = TaskID::from_random(); + { + flatbuffers::FlatBufferBuilder fbb; + std::vector> arguments; + for (auto &argument : task_arguments) { + arguments.push_back(argument->ToFlatbuffer(fbb)); + } + // General task. + auto spec = CreateTaskInfo( + fbb, to_flatbuf(fbb, DriverID::from_random()), to_flatbuf(fbb, task_id), + to_flatbuf(fbb, TaskID::from_random()), 0, to_flatbuf(fbb, ActorID::nil()), + to_flatbuf(fbb, ObjectID::nil()), 0, to_flatbuf(fbb, ActorID::nil()), + to_flatbuf(fbb, ActorHandleID::nil()), 0, + ids_to_flatbuf(fbb, std::vector()), fbb.CreateVector(arguments), 1, + map_to_flatbuf(fbb, {}), map_to_flatbuf(fbb, {}), Language::PYTHON, + string_vec_to_flatbuf(fbb, {"PackageName", "ClassName", "FunctionName"})); + fbb.Finish(spec); + RAY_LOG(ERROR) << "Ordinary task info size: " << fbb.GetSize(); + } + + { + flatbuffers::FlatBufferBuilder fbb; + std::vector> arguments; + for (auto &argument : task_arguments) { + arguments.push_back(argument->ToFlatbuffer(fbb)); + } + // Actor task.
+ auto spec = CreateTaskInfo( + fbb, to_flatbuf(fbb, DriverID::from_random()), to_flatbuf(fbb, task_id), + to_flatbuf(fbb, TaskID::from_random()), 10, + to_flatbuf(fbb, ActorID::from_random()), to_flatbuf(fbb, ObjectID::from_random()), + 10000000, to_flatbuf(fbb, ActorID::from_random()), + to_flatbuf(fbb, ActorHandleID::from_random()), 20, + ids_to_flatbuf(fbb, std::vector( + {ObjectID::from_random(), ObjectID::from_random()})), + fbb.CreateVector(arguments), 2, map_to_flatbuf(fbb, {}), map_to_flatbuf(fbb, {}), + Language::PYTHON, + string_vec_to_flatbuf(fbb, {"PackageName", "ClassName", "FunctionName"})); + fbb.Finish(spec); + RAY_LOG(ERROR) << "Actor task info size: " << fbb.GetSize(); + } +} + } // namespace raylet } // namespace ray
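
Note: the following is an illustrative, non-authoritative sketch (not part of the patch) of how the new RayletClient::SetResource API introduced above might be exercised from C++. The helper function name, the resource label "my_label", and the include path are assumptions for the example; only the SetResource signature, the nil-client-ID fallback to the local node, and the capacity <= 0 deletion semantics come from the diff itself.

    // Sketch only: calls RayletClient::SetResource as declared in
    // src/ray/raylet/raylet_client.h (include path and helper name assumed).
    #include <string>

    #include "ray/raylet/raylet_client.h"

    ray::Status UpdateCustomResource(RayletClient &raylet_client) {
      // Create or update a custom resource named "my_label" with capacity 4.
      // A nil ClientID is resolved to the local node by the node manager's
      // ProcessSetResourceRequest handler.
      ray::Status status =
          raylet_client.SetResource("my_label", 4.0, ray::ClientID::nil());
      if (!status.ok()) {
        return status;
      }
      // A capacity <= 0 is treated as a deletion request, which the node
      // manager publishes to the client table as EntryType::RES_DELETE.
      return raylet_client.SetResource("my_label", 0.0, ray::ClientID::nil());
    }

In the raylet, such a request travels as a SetResourceRequest flatbuffer message; the node manager appends a delta entry (EntryType::RES_CREATEUPDATE or EntryType::RES_DELETE) to the client table, and the ResourceCreateUpdated/ResourceDeleted callbacks registered in RegisterGcs update cluster_resource_map_ and local_available_resources_ on each node, after which infeasible tasks are re-evaluated via TryLocalInfeasibleTaskScheduling.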