diff --git a/.bazelrc b/.bazelrc index 488b33101594..3e3c3b6c4fa4 100644 --- a/.bazelrc +++ b/.bazelrc @@ -2,3 +2,5 @@ build --compilation_mode=opt build --action_env=PATH build --action_env=PYTHON_BIN_PATH +# This workaround is needed due to https://github.com/bazelbuild/bazel/issues/4341 +build --per_file_copt="external/com_github_grpc_grpc/.*@-DGRPC_BAZEL_BUILD" diff --git a/.travis.yml b/.travis.yml index 96bc82b24a63..49cb31aedca6 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,5 +1,7 @@ language: generic +dist: xenial + services: - docker diff --git a/BUILD.bazel b/BUILD.bazel index 90b0f536a10b..da36eec0cf57 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -1,12 +1,37 @@ # Bazel build # C/C++ documentation: https://docs.bazel.build/versions/master/be/c-cpp.html +load("@com_github_grpc_grpc//bazel:grpc_build_system.bzl", "grpc_proto_library") load("@com_github_google_flatbuffers//:build_defs.bzl", "flatbuffer_cc_library") load("@//bazel:ray.bzl", "flatbuffer_py_library") load("@//bazel:cython_library.bzl", "pyx_library") COPTS = ["-DRAY_USE_GLOG"] +# Node manager gRPC lib. +grpc_proto_library( + name = "node_manager_grpc_lib", + srcs = ["src/ray/protobuf/node_manager.proto"], +) + +# Node manager server and client. 
+cc_library( + name = "node_manager_rpc_lib", + srcs = glob([ + "src/ray/rpc/*.cc", + ]), + hdrs = glob([ + "src/ray/rpc/*.h", + ]), + copts = COPTS, + deps = [ + ":node_manager_grpc_lib", + ":ray_common", + "@boost//:asio", + "@com_github_grpc_grpc//:grpc++", + ], +) + cc_binary( name = "raylet", srcs = ["src/ray/raylet/main.cc"], @@ -89,6 +114,7 @@ cc_library( ":gcs", ":gcs_fbs", ":node_manager_fbs", + ":node_manager_rpc_lib", ":object_manager", ":ray_common", ":ray_util", @@ -111,13 +137,18 @@ cc_library( srcs = glob( [ "src/ray/core_worker/*.cc", + "src/ray/core_worker/store_provider/*.cc", + "src/ray/core_worker/transport/*.cc", ], exclude = [ "src/ray/core_worker/*_test.cc", + "src/ray/core_worker/mock_worker.cc", ], ), hdrs = glob([ "src/ray/core_worker/*.h", + "src/ray/core_worker/store_provider/*.h", + "src/ray/core_worker/transport/*.h", ]), copts = COPTS, deps = [ @@ -127,7 +158,15 @@ cc_library( ], ) -# This test is run by src/ray/test/run_core_worker_tests.sh +cc_binary( + name = "mock_worker", + srcs = ["src/ray/core_worker/mock_worker.cc"], + copts = COPTS, + deps = [ + ":core_worker_lib", + ], +) + cc_binary( name = "core_worker_test", srcs = ["src/ray/core_worker/core_worker_test.cc"], @@ -535,7 +574,7 @@ flatbuffer_py_library( "ErrorTableData.py", "ErrorType.py", "FunctionTableData.py", - "GcsTableEntry.py", + "GcsEntry.py", "HeartbeatBatchTableData.py", "HeartbeatTableData.py", "Language.py", diff --git a/README.rst b/README.rst index 06dd8115fdf3..60a4f5043019 100644 --- a/README.rst +++ b/README.rst @@ -6,7 +6,7 @@ .. image:: https://readthedocs.org/projects/ray/badge/?version=latest :target: http://ray.readthedocs.io/en/latest/?badge=latest -.. image:: https://img.shields.io/badge/pypi-0.7.0-blue.svg +.. 
image:: https://img.shields.io/badge/pypi-0.7.1-blue.svg :target: https://pypi.org/project/ray/ | diff --git a/bazel/ray_deps_build_all.bzl b/bazel/ray_deps_build_all.bzl index 5598d5820e35..3e1e1838a59a 100644 --- a/bazel/ray_deps_build_all.bzl +++ b/bazel/ray_deps_build_all.bzl @@ -3,6 +3,8 @@ load("@com_github_nelhage_rules_boost//:boost/boost.bzl", "boost_deps") load("@com_github_jupp0r_prometheus_cpp//:repositories.bzl", "prometheus_cpp_repositories") load("@com_github_ray_project_ray//bazel:python_configure.bzl", "python_configure") load("@com_github_checkstyle_java//:repo.bzl", "checkstyle_deps") +load("@com_github_grpc_grpc//bazel:grpc_deps.bzl", "grpc_deps") + def ray_deps_build_all(): gen_java_deps() @@ -10,3 +12,5 @@ def ray_deps_build_all(): boost_deps() prometheus_cpp_repositories() python_configure(name = "local_config_python") + grpc_deps() + diff --git a/bazel/ray_deps_setup.bzl b/bazel/ray_deps_setup.bzl index b3cd21b9b3b1..e6dc21585699 100644 --- a/bazel/ray_deps_setup.bzl +++ b/bazel/ray_deps_setup.bzl @@ -101,3 +101,11 @@ def ray_deps_setup(): # `https://github.com/jupp0r/prometheus-cpp/pull/225` getting merged. 
urls = ["https://github.com/jovany-wang/prometheus-cpp/archive/master.zip"], ) + + http_archive( + name = "com_github_grpc_grpc", + urls = [ + "https://github.com/grpc/grpc/archive/7741e806a213cba63c96234f16d712a8aa101a49.tar.gz", + ], + strip_prefix = "grpc-7741e806a213cba63c96234f16d712a8aa101a49", + ) diff --git a/ci/jenkins_tests/perf_integration_tests/run_perf_integration.sh b/ci/jenkins_tests/perf_integration_tests/run_perf_integration.sh index f723d5122981..7962b21075c0 100755 --- a/ci/jenkins_tests/perf_integration_tests/run_perf_integration.sh +++ b/ci/jenkins_tests/perf_integration_tests/run_perf_integration.sh @@ -9,7 +9,7 @@ pushd "$ROOT_DIR" python -m pip install pytest-benchmark -pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-cp27-cp27mu-manylinux1_x86_64.whl +pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev1-cp27-cp27mu-manylinux1_x86_64.whl python -m pytest --benchmark-autosave --benchmark-min-rounds=10 --benchmark-columns="min, max, mean" $ROOT_DIR/../../../python/ray/tests/perf_integration_tests/test_perf_integration.py pushd $ROOT_DIR/../../../python diff --git a/ci/jenkins_tests/run_rllib_tests.sh b/ci/jenkins_tests/run_rllib_tests.sh index a97bf5517ea2..13036ae7da0f 100644 --- a/ci/jenkins_tests/run_rllib_tests.sh +++ b/ci/jenkins_tests/run_rllib_tests.sh @@ -392,6 +392,16 @@ docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ /ray/ci/suppress_output python /ray/python/ray/rllib/examples/rollout_worker_custom_workflow.py +docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ + /ray/ci/suppress_output python /ray/python/ray/rllib/examples/eager_execution.py --iters=2 + +docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ + /ray/ci/suppress_output /ray/python/ray/rllib/train.py \ + --env CartPole-v0 \ + --run PPO \ + --stop 
'{"training_iteration": 1}' \ + --config '{"use_eager": true, "simple_optimizer": true}' + docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ /ray/ci/suppress_output python /ray/python/ray/rllib/examples/custom_tf_policy.py --iters=2 diff --git a/ci/stress_tests/application_cluster_template.yaml b/ci/stress_tests/application_cluster_template.yaml index 541419da55af..9218c2cf7356 100644 --- a/ci/stress_tests/application_cluster_template.yaml +++ b/ci/stress_tests/application_cluster_template.yaml @@ -37,7 +37,7 @@ provider: # Availability zone(s), comma-separated, that nodes may be launched in. # Nodes are currently spread between zones by a round-robin approach, # however this implementation detail should not be relied upon. - availability_zone: us-west-2a,us-west-2b + availability_zone: us-west-2b # How Ray will authenticate with newly launched nodes. auth: @@ -90,8 +90,8 @@ file_mounts: { # List of shell commands to run to set up nodes. setup_commands: - echo 'export PATH="$HOME/anaconda3/envs/tensorflow_<<<PYTHON_VERSION>>>/bin:$PATH"' >> ~/.bashrc - - ray || wget https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-<<<WHEEL_STR>>>-manylinux1_x86_64.whl - - rllib || pip install -U ray-0.8.0.dev0-<<<WHEEL_STR>>>-manylinux1_x86_64.whl[rllib] + - ray || wget https://s3-us-west-2.amazonaws.com/ray-wheels/releases/<<<RAY_VERSION>>>/<<<RAY_COMMIT>>>/ray-<<<RAY_VERSION>>>-<<<WHEEL_STR>>>-manylinux1_x86_64.whl + - rllib || pip install -U ray-<<<RAY_VERSION>>>-<<<WHEEL_STR>>>-manylinux1_x86_64.whl[rllib] - pip install tensorflow-gpu==1.12.0 - echo "sudo halt" | at now + 60 minutes # Consider uncommenting these if you also want to run apt-get commands during setup diff --git a/ci/stress_tests/run_application_stress_tests.sh b/ci/stress_tests/run_application_stress_tests.sh index a8ded40fa797..293530928745 100755 --- a/ci/stress_tests/run_application_stress_tests.sh +++ b/ci/stress_tests/run_application_stress_tests.sh @@ -1,4 +1,11 @@ #!/usr/bin/env bash + +# This script should be run as follows: +# ./run_application_stress_tests.sh <ray-version> <commit> +# For
example, <ray-version> might be 0.7.1 +# and <commit> might be bc3b6efdb6933d410563ee70f690855c05f25483. The commit +# should be the latest commit on the branch "releases/<ray-version>". + # This script runs all of the application tests. # Currently includes an IMPALA stress test and a SGD stress test. # on both Python 2.7 and 3.6. @@ -10,26 +17,39 @@ # This script will exit with code 1 if the test did not run successfully. +# Show explicitly which commands are currently running. This should only be AFTER +# the private key is placed. +set -x ROOT_DIR=$(cd "$(dirname "${BASH_SOURCE:-$0}")"; pwd) RESULT_FILE=$ROOT_DIR/"results-$(date '+%Y-%m-%d_%H-%M-%S').log" -echo "Logging to" $RESULT_FILE -echo -e $RAY_AWS_SSH_KEY > /root/.ssh/ray-autoscaler_us-west-2.pem && chmod 400 /root/.ssh/ray-autoscaler_us-west-2.pem || true +touch "$RESULT_FILE" +echo "Logging to" "$RESULT_FILE" +if [[ -z "$1" ]]; then + echo "ERROR: The first argument must be the Ray version string." + exit 1 +else + RAY_VERSION=$1 +fi -# Show explicitly which commands are currently running. This should only be AFTER -# the private key is placed. -set -x +if [[ -z "$2" ]]; then + echo "ERROR: The second argument must be the commit hash to test." + exit 1 +else + RAY_COMMIT=$2 +fi -touch $RESULT_FILE +echo "Testing ray==$RAY_VERSION at commit $RAY_COMMIT." +echo "The wheels used will live under https://s3-us-west-2.amazonaws.com/ray-wheels/releases/$RAY_VERSION/$RAY_COMMIT/" # This function identifies the right string for the Ray wheel. _find_wheel_str(){ local python_version=$1 # echo "PYTHON_VERSION", $python_version local wheel_str="" - if [ $python_version == "p27" ]; then + if [ "$python_version" == "p27" ]; then wheel_str="cp27-cp27mu" else wheel_str="cp36-cp36m" @@ -41,7 +61,7 @@ _find_wheel_str(){ # Actual test runtime is roughly 10 minutes.
test_impala(){ local PYTHON_VERSION=$1 - local WHEEL_STR=$(_find_wheel_str $PYTHON_VERSION) + local WHEEL_STR=$(_find_wheel_str "$PYTHON_VERSION") pushd "$ROOT_DIR" local TEST_NAME="rllib_impala_$PYTHON_VERSION" @@ -50,32 +70,34 @@ test_impala(){ cat application_cluster_template.yaml | sed -e " + s/<<<RAY_VERSION>>>/$RAY_VERSION/g; + s/<<<RAY_COMMIT>>>/$RAY_COMMIT/; s/<<<CLUSTER_NAME>>>/$TEST_NAME/; - s/<<<HEAD_TYPE>>>/g3.16xlarge/; + s/<<<HEAD_TYPE>>>/p3.16xlarge/; s/<<<WORKER_TYPE>>>/m5.24xlarge/; s/<<<MIN_WORKERS>>>/5/; s/<<<MAX_WORKERS>>>/5/; s/<<<PYTHON_VERSION>>>/$PYTHON_VERSION/; - s/<<<WHEEL_STR>>>/$WHEEL_STR/;" > $CLUSTER + s/<<<WHEEL_STR>>>/$WHEEL_STR/;" > "$CLUSTER" echo "Try running IMPALA stress test." { RLLIB_DIR=../../python/ray/rllib/ - ray --logging-level=DEBUG up -y $CLUSTER && - ray rsync_up $CLUSTER $RLLIB_DIR/tuned_examples/ tuned_examples/ && + ray --logging-level=DEBUG up -y "$CLUSTER" && + ray rsync_up "$CLUSTER" $RLLIB_DIR/tuned_examples/ tuned_examples/ && sleep 1 && - ray --logging-level=DEBUG exec $CLUSTER "rllib || true" && - ray --logging-level=DEBUG exec $CLUSTER " + ray --logging-level=DEBUG exec "$CLUSTER" "rllib || true" && + ray --logging-level=DEBUG exec "$CLUSTER" " rllib train -f tuned_examples/atari-impala-large.yaml --redis-address='localhost:6379' --queue-trials" && - echo "PASS: IMPALA Test for" $PYTHON_VERSION >> $RESULT_FILE - } || echo "FAIL: IMPALA Test for" $PYTHON_VERSION >> $RESULT_FILE + echo "PASS: IMPALA Test for" "$PYTHON_VERSION" >> "$RESULT_FILE" + } || echo "FAIL: IMPALA Test for" "$PYTHON_VERSION" >> "$RESULT_FILE" # Tear down cluster.
if [ "$DEBUG_MODE" = "" ]; then - ray down -y $CLUSTER - rm $CLUSTER + ray down -y "$CLUSTER" + rm "$CLUSTER" else - echo "Not tearing down cluster" $CLUSTER + echo "Not tearing down cluster" "$CLUSTER" fi popd } @@ -93,32 +115,34 @@ test_sgd(){ cat application_cluster_template.yaml | sed -e " + s/<<<RAY_VERSION>>>/$RAY_VERSION/g; + s/<<<RAY_COMMIT>>>/$RAY_COMMIT/; s/<<<CLUSTER_NAME>>>/$TEST_NAME/; - s/<<<HEAD_TYPE>>>/g3.16xlarge/; - s/<<<WORKER_TYPE>>>/g3.16xlarge/; + s/<<<HEAD_TYPE>>>/p3.16xlarge/; + s/<<<WORKER_TYPE>>>/p3.16xlarge/; s/<<<MIN_WORKERS>>>/3/; s/<<<MAX_WORKERS>>>/3/; s/<<<PYTHON_VERSION>>>/$PYTHON_VERSION/; - s/<<<WHEEL_STR>>>/$WHEEL_STR/;" > $CLUSTER + s/<<<WHEEL_STR>>>/$WHEEL_STR/;" > "$CLUSTER" echo "Try running SGD stress test." { SGD_DIR=$ROOT_DIR/../../python/ray/experimental/sgd/ - ray --logging-level=DEBUG up -y $CLUSTER && + ray --logging-level=DEBUG up -y "$CLUSTER" && # TODO: fix submit so that args work - ray rsync_up $CLUSTER $SGD_DIR/mnist_example.py mnist_example.py && + ray rsync_up "$CLUSTER" "$SGD_DIR/mnist_example.py" mnist_example.py && sleep 1 && - ray --logging-level=DEBUG exec $CLUSTER " + ray --logging-level=DEBUG exec "$CLUSTER" " python mnist_example.py --redis-address=localhost:6379 --num-iters=2000 --num-workers=8 --devices-per-worker=2 --gpu" && - echo "PASS: SGD Test for" $PYTHON_VERSION >> $RESULT_FILE - } || echo "FAIL: SGD Test for" $PYTHON_VERSION >> $RESULT_FILE + echo "PASS: SGD Test for" "$PYTHON_VERSION" >> "$RESULT_FILE" + } || echo "FAIL: SGD Test for" "$PYTHON_VERSION" >> "$RESULT_FILE" # Tear down cluster. if [ "$DEBUG_MODE" = "" ]; then - ray down -y $CLUSTER - rm $CLUSTER + ray down -y "$CLUSTER" + rm "$CLUSTER" else - echo "Not tearing down cluster" $CLUSTER + echo "Not tearing down cluster" "$CLUSTER" fi popd } @@ -130,6 +154,6 @@ do test_sgd $PYTHON_VERSION done -cat $RESULT_FILE -cat $RESULT_FILE | grep FAIL > test.log +cat "$RESULT_FILE" +cat "$RESULT_FILE" | grep FAIL > test.log [ !
-s test.log ] || exit 1 diff --git a/ci/stress_tests/run_stress_tests.sh b/ci/stress_tests/run_stress_tests.sh index 1d4d102092ee..f92e8c592d40 100755 --- a/ci/stress_tests/run_stress_tests.sh +++ b/ci/stress_tests/run_stress_tests.sh @@ -1,40 +1,61 @@ #!/usr/bin/env bash +# Show explicitly which commands are currently running. +set -x + ROOT_DIR=$(cd "$(dirname "${BASH_SOURCE:-$0}")"; pwd) RESULT_FILE=$ROOT_DIR/results-$(date '+%Y-%m-%d_%H-%M-%S').log -echo "Logging to" $RESULT_FILE -echo -e $RAY_AWS_SSH_KEY > /root/.ssh/ray-autoscaler_us-west-2.pem && chmod 400 /root/.ssh/ray-autoscaler_us-west-2.pem || true +touch "$RESULT_FILE" +echo "Logging to" "$RESULT_FILE" -# Show explicitly which commands are currently running. This should only be AFTER -# the private key is placed. -set -x +if [[ -z "$1" ]]; then + echo "ERROR: The first argument must be the Ray version string." + exit 1 +else + RAY_VERSION=$1 +fi -touch $RESULT_FILE +if [[ -z "$2" ]]; then + echo "ERROR: The second argument must be the commit hash to test." + exit 1 +else + RAY_COMMIT=$2 +fi + +echo "Testing ray==$RAY_VERSION at commit $RAY_COMMIT." +echo "The wheels used will live under https://s3-us-west-2.amazonaws.com/ray-wheels/releases/$RAY_VERSION/$RAY_COMMIT/" run_test(){ local test_name=$1 - local CLUSTER="stress_testing_config.yaml" + local CLUSTER="stress_testing_config_temporary.yaml" + + cat stress_testing_config.yaml | + sed -e " + s/<<<RAY_VERSION>>>/$RAY_VERSION/g; + s/<<<RAY_COMMIT>>>/$RAY_COMMIT/;" > "$CLUSTER" + echo "Try running $test_name." { ray up -y $CLUSTER --cluster-name "$test_name" && sleep 1 && - ray --logging-level=DEBUG submit $CLUSTER --cluster-name "$test_name" "$test_name.py" + ray --logging-level=DEBUG submit "$CLUSTER" --cluster-name "$test_name" "$test_name.py" - } || echo "FAIL: $test_name" >> $RESULT_FILE + } || echo "FAIL: $test_name" >> "$RESULT_FILE" # Tear down cluster.
if [ "$DEBUG_MODE" = "" ]; then ray down -y $CLUSTER --cluster-name "$test_name" + rm "$CLUSTER" else - echo "Not tearing down cluster" $CLUSTER + echo "Not tearing down cluster" "$CLUSTER" fi } pushd "$ROOT_DIR" - run_test test_many_tasks_and_transfers + run_test test_many_tasks run_test test_dead_actors popd -cat $RESULT_FILE -[ ! -s $RESULT_FILE ] || exit 1 +cat "$RESULT_FILE" +[ ! -s "$RESULT_FILE" ] || exit 1 diff --git a/ci/stress_tests/stress_testing_config.yaml b/ci/stress_tests/stress_testing_config.yaml index f71ae8f2dc18..ae878963094f 100644 --- a/ci/stress_tests/stress_testing_config.yaml +++ b/ci/stress_tests/stress_testing_config.yaml @@ -101,7 +101,7 @@ setup_commands: # - ray/ci/travis/install-bazel.sh - pip install boto3==1.4.8 cython==0.29.0 # - cd ray/python; git checkout master; git pull; pip install -e . --verbose - - pip install https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-cp36-cp36m-manylinux1_x86_64.whl + - pip install https://s3-us-west-2.amazonaws.com/ray-wheels/releases/<<<RAY_VERSION>>>/<<<RAY_COMMIT>>>/ray-<<<RAY_VERSION>>>-cp36-cp36m-manylinux1_x86_64.whl - echo "sudo halt" | at now + 60 minutes # Custom commands that will be run on the head node after common setup. diff --git a/ci/stress_tests/test_many_tasks_and_transfers.py b/ci/stress_tests/test_many_tasks.py similarity index 100% rename from ci/stress_tests/test_many_tasks_and_transfers.py rename to ci/stress_tests/test_many_tasks.py diff --git a/ci/suppress_output b/ci/suppress_output index 623559d11cbc..0f32b1a88b37 100755 --- a/ci/suppress_output +++ b/ci/suppress_output @@ -23,7 +23,7 @@ time "$@" >$TMPFILE 2>&1 CODE=$?
if [ $CODE != 0 ]; then - cat $TMPFILE + tail -n 2000 $TMPFILE echo "FAILED $CODE" kill $WATCHDOG_PID exit $CODE diff --git a/ci/travis/check-git-clang-format-output.sh b/ci/travis/check-git-clang-format-output.sh index 4209811cd21c..6d83044c4877 100755 --- a/ci/travis/check-git-clang-format-output.sh +++ b/ci/travis/check-git-clang-format-output.sh @@ -8,7 +8,7 @@ else base_commit="$TRAVIS_BRANCH" echo "Running clang-format against branch $base_commit, with hash $(git rev-parse $base_commit)" fi -output="$(ci/travis/git-clang-format --binary clang-format-3.8 --commit $base_commit --diff --exclude '(.*thirdparty/|.*redismodule.h|.*.js|.*.java)')" +output="$(ci/travis/git-clang-format --binary clang-format --commit $base_commit --diff --exclude '(.*thirdparty/|.*redismodule.h|.*.js|.*.java)')" if [ "$output" == "no modified files to format" ] || [ "$output" == "clang-format did not modify any files" ] ; then echo "clang-format passed." exit 0 diff --git a/ci/travis/install-bazel.sh b/ci/travis/install-bazel.sh index c9614f7722ef..5b6d9572952e 100755 --- a/ci/travis/install-bazel.sh +++ b/ci/travis/install-bazel.sh @@ -16,7 +16,7 @@ else exit 1 fi -URL="https://github.com/bazelbuild/bazel/releases/download/0.21.0/bazel-0.21.0-installer-${platform}-x86_64.sh" +URL="https://github.com/bazelbuild/bazel/releases/download/0.26.1/bazel-0.26.1-installer-${platform}-x86_64.sh" wget -O install.sh $URL chmod +x install.sh ./install.sh --user diff --git a/dev/RELEASE_PROCESS.rst b/dev/RELEASE_PROCESS.rst index 62862506e1ed..3b78cef5eda5 100644 --- a/dev/RELEASE_PROCESS.rst +++ b/dev/RELEASE_PROCESS.rst @@ -6,38 +6,45 @@ This document describes the process for creating new releases. 1. **Increment the Python version:** Create a PR that increments the Python package version. See `this example`_. -2. **Download the Travis-built wheels:** Once Travis has completed the tests, - the wheels from this commit can be downloaded from S3 to do testing, etc. 
- The URL is structured like this: - ``https://s3-us-west-2.amazonaws.com/ray-wheels//`` - where ```` is replaced by the ID of the commit and the ```` - is the incremented version from the previous step. The ```` can - be determined by looking at the OS/Version matrix in the documentation_. - -3. **Create a release branch:** This branch should also have the same commit ID as the - previous two steps. In order to create the branch, locally checkout the commit ID - i.e. ``git checkout ``. Then checkout a new branch of the format - ``releases/``. The release number must match the increment in - the first step. Then push that branch to the ray repo: - ``git push upstream releases/``. +2. **Bump version on Ray master branch again:** Create a pull request to + increment the version of the master branch. The format of the new version is + as follows: + + New minor release (e.g., 0.7.0): Increment the minor version and append + ``.dev0`` to the version. For example, if the version of the new release is + 0.7.0, the master branch needs to be updated to 0.8.0.dev0. + + New micro release (e.g., 0.7.1): Increment the ``dev`` number, such that the + number after ``dev`` equals the micro version. For example, if the version + of the new release is 0.7.1, the master branch needs to be updated to + 0.8.0.dev1. + + This can be merged as soon as step 1 is complete. + +3. **Create a release branch:** Create the branch from the version bump PR (the + one from step 1, not step 2). In order to create the branch, locally checkout + the commit ID i.e., ``git checkout ``. Then checkout a new branch of + the format ``releases/``. Then push that branch to the ray + repo: ``git push upstream releases/``. 4. **Testing:** Before a release is created, significant testing should be done. - Run the scripts `ci/stress_tests/run_stress_tests.sh`_ and - `ci/stress_tests/run_application_stress_tests.sh`_ and make sure they - pass. 
You **MUST** modify the autoscaler config file and replace - ``<>`` and ``<>`` with the appropriate - values to test the correct wheels. This will use the autoscaler to start a bunch of - machines and run some tests. Any new stress tests should be added to this - script so that they will be run automatically for future release testing. - -5. **Resolve release-blockers:** Should any release blocking issues arise, - there are two ways these issues are resolved: A PR to patch the issue or a - revert commit that removes the breaking change from the release. In the case - of a PR, that PR should be created against master. Once it is merged, the - release master should ``git cherry-pick`` the commit to the release branch. - If the decision is to revert a commit that caused the release blocker, the - release master should ``git revert`` the commit to be reverted on the - release branch. Push these changes directly to the release branch. + Run the following scripts + + .. code-block:: bash + + ray/ci/stress_tests/run_stress_tests.sh + ray/ci/stress_tests/run_application_stress_tests.sh + + and make sure they pass. If they pass, it will be obvious that they passed. + This will use the autoscaler to start a bunch of machines and run some tests. + +5. **Resolve release-blockers:** If a release blocking issue arises, there are + two ways the issue can be resolved: 1) Fix the issue on the master branch and + cherry-pick the relevant commit (using ``git cherry-pick``) onto the release + branch. 2) Revert the commit that introduced the bug on the release branch + (using ``git revert``), but not on the master. + + These changes should then be pushed directly to the release branch. 6. **Download all the wheels:** Now the release is ready to begin final testing. The wheels are automatically uploaded to S3, even on the release @@ -47,20 +54,20 @@ This document describes the process for creating new releases. export RAY_HASH=... 
# e.g., 618147f57fb40368448da3b2fb4fd213828fa12b export RAY_VERSION=... # e.g., 0.7.0 - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/$RAY_HASH/ray-$RAY_VERSION-cp27-cp27mu-manylinux1_x86_64.whl - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/$RAY_HASH/ray-$RAY_VERSION-cp35-cp35m-manylinux1_x86_64.whl - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/$RAY_HASH/ray-$RAY_VERSION-cp36-cp36m-manylinux1_x86_64.whl - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/$RAY_HASH/ray-$RAY_VERSION-cp37-cp37m-manylinux1_x86_64.whl - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/$RAY_HASH/ray-$RAY_VERSION-cp27-cp27m-macosx_10_6_intel.whl - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/$RAY_HASH/ray-$RAY_VERSION-cp35-cp35m-macosx_10_6_intel.whl - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/$RAY_HASH/ray-$RAY_VERSION-cp36-cp36m-macosx_10_6_intel.whl - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/$RAY_HASH/ray-$RAY_VERSION-cp37-cp37m-macosx_10_6_intel.whl + pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/releases/$RAY_VERSION/$RAY_HASH/ray-$RAY_VERSION-cp27-cp27mu-manylinux1_x86_64.whl + pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/releases/$RAY_VERSION/$RAY_HASH/ray-$RAY_VERSION-cp35-cp35m-manylinux1_x86_64.whl + pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/releases/$RAY_VERSION/$RAY_HASH/ray-$RAY_VERSION-cp36-cp36m-manylinux1_x86_64.whl + pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/releases/$RAY_VERSION/$RAY_HASH/ray-$RAY_VERSION-cp37-cp37m-manylinux1_x86_64.whl + pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/releases/$RAY_VERSION/$RAY_HASH/ray-$RAY_VERSION-cp27-cp27m-macosx_10_6_intel.whl + pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/releases/$RAY_VERSION/$RAY_HASH/ray-$RAY_VERSION-cp35-cp35m-macosx_10_6_intel.whl + pip install -U 
https://s3-us-west-2.amazonaws.com/ray-wheels/releases/$RAY_VERSION/$RAY_HASH/ray-$RAY_VERSION-cp36-cp36m-macosx_10_6_intel.whl + pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/releases/$RAY_VERSION/$RAY_HASH/ray-$RAY_VERSION-cp37-cp37m-macosx_10_6_intel.whl 7. **Final Testing:** Send a link to the wheels to the other contributors and - core members of the Ray project. Make sure the wheels are tested on Ubuntu, - Mac OSX 10.12, and Mac OSX 10.13+. This testing should verify that the - wheels are correct and that all release blockers have been resolved. Should - a new release blocker be found, repeat steps 5-7. + core members of the Ray project. Make sure the wheels are tested on Ubuntu + and MacOS (ideally multiple versions of Ubuntu and MacOS). This testing + should verify that the wheels are correct and that all release blockers have + been resolved. Should a new release blocker be found, repeat steps 5-7. 8. **Upload to PyPI Test:** Upload the wheels to the PyPI test site using ``twine`` (ask Robert to add you as a maintainer to the PyPI project). You'll @@ -68,11 +75,11 @@ This document describes the process for creating new releases. .. code-block:: bash - twine upload --repository-url https://test.pypi.org/legacy/ray/.whl/* + twine upload --repository-url https://test.pypi.org/legacy/ ray/.whl/* assuming that you've downloaded the wheels from the ``ray-wheels`` S3 bucket and put them in ``ray/.whl``, that you've installed ``twine`` through - ``pip``, and that you've made PyPI accounts. + ``pip``, and that you've created both PyPI accounts. Test that you can install the wheels with pip from the PyPI test repository with @@ -86,7 +93,7 @@ This document describes the process for creating new releases. installed by checking ``ray.__version__`` and ``ray.__file__``. Do this at least for MacOS and for Linux, as well as for Python 2 and Python - 3. Also do this for different versions of MacOS. + 3. 9. 
**Upload to PyPI:** Now that you've tested the wheels on the PyPI test repository, they can be uploaded to the main PyPI repository. Be careful, @@ -107,41 +114,31 @@ This document describes the process for creating new releases. finds the correct Ray version, and successfully runs some simple scripts on both MacOS and Linux as well as Python 2 and Python 3. -10. **Create a GitHub release:** Create a GitHub release through the `GitHub website`_. - The release should be created at the commit from the previous - step. This should include **release notes**. Copy the style and formatting - used by previous releases. Create a draft of the release notes containing - information about substantial changes/updates/bugfixes and their PR number. - Once you have a draft, make sure you solicit feedback from other Ray - developers before publishing. Use the following to get started: +10. **Create a GitHub release:** Create a GitHub release through the + `GitHub website`_. The release should be created at the commit from the + previous step. This should include **release notes**. Copy the style and + formatting used by previous releases. Create a draft of the release notes + containing information about substantial changes/updates/bugfixes and their + PR numbers. Once you have a draft, make sure you solicit feedback from other + Ray developers before publishing. Use the following to get started: .. code-block:: bash git pull origin master --tags git log $(git describe --tags --abbrev=0)..HEAD --pretty=format:"%s" | sort -11. **Bump version on Ray master branch:** Create a pull request to increment the - version of the master branch. The format of the new version is as follows: - - New minor release (e.g., 0.7.0): Increment the minor version and append ``.dev0`` to - the version. For example, if the version of the new release is 0.7.0, the master - branch needs to be updated to 0.8.0.dev0. 
`Example PR for minor release` - - New micro release (e.g., 0.7.1): Increment the ``dev`` number, such that the number - after ``dev`` equals the micro version. For example, if the version of the new - release is 0.7.1, the master branch needs to be updated to 0.8.0.dev1. +11. **Update version numbers throughout codebase:** Suppose we just released + 0.7.1. The previous release version number (in this case 0.7.0) and the + previous dev version number (in this case 0.8.0.dev0) appear in many places + throughout the code base including the installation documentation, the + example autoscaler config files, and the testing scripts. Search for all of + the occurrences of these version numbers and update them to use the new + release and dev version numbers. **NOTE:** Not all of the version numbers + should be replaced. For example, ``0.7.0`` appears in this file but should + not be updated. -12. **Update version numbers throughout codebase:** Suppose we just released 0.7.1. The - previous release version number (in this case 0.7.0) and the previous dev version number - (in this case 0.8.0.dev0) appear in many places throughout the code base including - the installation documentation, the example autoscaler config files, and the testing - scripts. Search for all of the occurrences of these version numbers and update them to - use the new release and dev version numbers. +12. **Improve the release process:** Find some way to improve the release + process so that whoever manages the release next will have an easier time. -.. _documentation: https://ray.readthedocs.io/en/latest/installation.html#trying-snapshots-from-master -.. _`documentation for building wheels`: https://github.com/ray-project/ray/blob/master/python/README-building-wheels.md -.. _`ci/stress_tests/run_stress_tests.sh`: https://github.com/ray-project/ray/blob/master/ci/stress_tests/run_stress_tests.sh -.. 
_`ci/stress_tests/run_application_stress_tests.sh`: https://github.com/ray-project/ray/blob/master/ci/stress_tests/run_application_stress_tests.sh .. _`this example`: https://github.com/ray-project/ray/pull/4226 .. _`GitHub website`: https://github.com/ray-project/ray/releases -.. _`Example PR for minor release`: https://github.com/ray-project/ray/pull/4845 diff --git a/doc/source/conf.py b/doc/source/conf.py index b0ae3416d4ab..98fb3e0d02dd 100644 --- a/doc/source/conf.py +++ b/doc/source/conf.py @@ -29,7 +29,7 @@ "ray.core.generated.EntryType", "ray.core.generated.ErrorTableData", "ray.core.generated.ErrorType", - "ray.core.generated.GcsTableEntry", + "ray.core.generated.GcsEntry", "ray.core.generated.HeartbeatBatchTableData", "ray.core.generated.HeartbeatTableData", "ray.core.generated.Language", diff --git a/doc/source/development.rst b/doc/source/development.rst index 1fdc65fa35cf..ecbed6c31f9e 100644 --- a/doc/source/development.rst +++ b/doc/source/development.rst @@ -29,8 +29,11 @@ recompile much more quickly by doing .. code-block:: shell - cd ray/build - make -j8 + cd ray + bazel build //:ray_pkg + +This command is not enough to recompile all C++ unit tests. To do so, see +`Testing locally`_. Debugging --------- @@ -144,6 +147,14 @@ When running tests, usually only the first test failure matters. A single test failure often triggers the failure of subsequent tests in the same script. +To compile and run all C++ tests, you can run: + +.. code-block:: shell + + cd ray + bazel test $(bazel query 'kind(cc_test, ...)') + + Linting ------- diff --git a/doc/source/example-a3c.rst b/doc/source/example-a3c.rst index 3037b2b6b132..f8a8bfb4c1f3 100644 --- a/doc/source/example-a3c.rst +++ b/doc/source/example-a3c.rst @@ -127,7 +127,7 @@ global model parameters. The main training script looks like the following. 
obs = 0 # Start simulations on actors - agents = [Runner(env_name, i) for i in range(num_workers)] + agents = [Runner.remote(env_name, i) for i in range(num_workers)] # Start gradient calculation tasks on each actor parameters = policy.get_weights() diff --git a/doc/source/installation.rst b/doc/source/installation.rst index ad92cb347e83..b7cb27c831b6 100644 --- a/doc/source/installation.rst +++ b/doc/source/installation.rst @@ -33,14 +33,14 @@ Here are links to the latest wheels (which are built off of master). To install =================== =================== -.. _`Linux Python 3.7`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-cp37-cp37m-manylinux1_x86_64.whl -.. _`Linux Python 3.6`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-cp36-cp36m-manylinux1_x86_64.whl -.. _`Linux Python 3.5`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-cp35-cp35m-manylinux1_x86_64.whl -.. _`Linux Python 2.7`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-cp27-cp27mu-manylinux1_x86_64.whl -.. _`MacOS Python 3.7`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-cp37-cp37m-macosx_10_6_intel.whl -.. _`MacOS Python 3.6`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-cp36-cp36m-macosx_10_6_intel.whl -.. _`MacOS Python 3.5`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-cp35-cp35m-macosx_10_6_intel.whl -.. _`MacOS Python 2.7`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-cp27-cp27m-macosx_10_6_intel.whl +.. _`Linux Python 3.7`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev1-cp37-cp37m-manylinux1_x86_64.whl +.. _`Linux Python 3.6`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev1-cp36-cp36m-manylinux1_x86_64.whl +.. _`Linux Python 3.5`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev1-cp35-cp35m-manylinux1_x86_64.whl +.. 
_`Linux Python 2.7`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev1-cp27-cp27mu-manylinux1_x86_64.whl +.. _`MacOS Python 3.7`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev1-cp37-cp37m-macosx_10_6_intel.whl +.. _`MacOS Python 3.6`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev1-cp36-cp36m-macosx_10_6_intel.whl +.. _`MacOS Python 3.5`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev1-cp35-cp35m-macosx_10_6_intel.whl +.. _`MacOS Python 2.7`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev1-cp27-cp27m-macosx_10_6_intel.whl Building Ray from source diff --git a/doc/source/rllib-concepts.rst b/doc/source/rllib-concepts.rst index b7b3ff823774..4b00f5636540 100644 --- a/doc/source/rllib-concepts.rst +++ b/doc/source/rllib-concepts.rst @@ -346,6 +346,37 @@ In PPO we run ``setup_mixins`` before the loss function is called (i.e., ``befor Finally, note that you do not have to use ``build_tf_policy`` to define a TensorFlow policy. You can alternatively subclass ``Policy``, ``TFPolicy``, or ``DynamicTFPolicy`` as convenient. +Building Policies in TensorFlow Eager +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +While RLlib runs all TF operations in graph mode, you can still leverage TensorFlow eager using `tf.py_function `__. However, note that eager and non-eager tensors cannot be mixed within the ``py_function``. Here's an example of embedding eager execution within a policy loss function: + +.. code-block:: python + + def eager_loss(policy, batch_tensors): + """Example of using embedded eager execution in a custom loss. + + Here `compute_penalty` prints the actions and rewards for debugging, and + also computes a (dummy) penalty term to add to the loss. 
+ """ + + def compute_penalty(actions, rewards): + penalty = tf.reduce_mean(tf.cast(actions, tf.float32)) + if random.random() > 0.9: + print("The eagerly computed penalty is", penalty, actions, rewards) + return penalty + + actions = batch_tensors[SampleBatch.ACTIONS] + rewards = batch_tensors[SampleBatch.REWARDS] + penalty = tf.py_function( + compute_penalty, [actions, rewards], Tout=tf.float32) + + return penalty - tf.reduce_mean(policy.action_dist.logp(actions) * rewards) + +You can find a runnable file for the above eager execution example `here `__. + +There is also experimental support for running the entire loss function in eager mode. This can be enabled with ``use_eager: True``, e.g., ``rllib train --env=CartPole-v0 --run=PPO --config='{"use_eager": true, "simple_optimizer": true}'``. However this currently only works for a couple algorithms. + Building Policies in PyTorch ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/doc/source/rllib-examples.rst b/doc/source/rllib-examples.rst index 13bfdc68bfc1..604abf394de1 100644 --- a/doc/source/rllib-examples.rst +++ b/doc/source/rllib-examples.rst @@ -38,6 +38,8 @@ Custom Envs and Models Example of adding batch norm layers to a custom model. - `Parametric actions `__: Example of how to handle variable-length or parametric action spaces. +- `Eager execution `__: + Example of how to leverage TensorFlow eager to simplify debugging and design of custom models and policies. Serving and Offline ------------------- diff --git a/doc/source/rllib-package-ref.rst b/doc/source/rllib-package-ref.rst index db4b2dbfe0eb..6a4e6aed43f8 100644 --- a/doc/source/rllib-package-ref.rst +++ b/doc/source/rllib-package-ref.rst @@ -1,25 +1,11 @@ RLlib Package Reference ======================= -ray.rllib.agents +ray.rllib.policy ---------------- -.. automodule:: ray.rllib.agents +.. automodule:: ray.rllib.policy :members: - -.. autoclass:: ray.rllib.agents.a3c.A2CTrainer -.. autoclass:: ray.rllib.agents.a3c.A3CTrainer -.. 
autoclass:: ray.rllib.agents.ddpg.ApexDDPGTrainer -.. autoclass:: ray.rllib.agents.ddpg.DDPGTrainer -.. autoclass:: ray.rllib.agents.dqn.ApexTrainer -.. autoclass:: ray.rllib.agents.dqn.DQNTrainer -.. autoclass:: ray.rllib.agents.es.ESTrainer -.. autoclass:: ray.rllib.agents.pg.PGTrainer -.. autoclass:: ray.rllib.agents.impala.ImpalaTrainer -.. autoclass:: ray.rllib.agents.ppo.APPOTrainer -.. autoclass:: ray.rllib.agents.ppo.PPOTrainer -.. autoclass:: ray.rllib.agents.marwil.MARWILTrainer - ray.rllib.env ------------- diff --git a/doc/source/rllib-training.rst b/doc/source/rllib-training.rst index 824ef4c3dd88..9c365f8fb427 100644 --- a/doc/source/rllib-training.rst +++ b/doc/source/rllib-training.rst @@ -367,6 +367,13 @@ The ``"monitor": true`` config can be used to save Gym episode videos to the res openaigym.video.0.31403.video000000.meta.json openaigym.video.0.31403.video000000.mp4 +TensorFlow Eager +~~~~~~~~~~~~~~~~ + +While RLlib uses TF graph mode for all computations, you can still leverage TF eager to inspect the intermediate state of computations using `tf.py_function `__. Here's an example of using eager mode in `a custom RLlib model and loss `__. + +There is also experimental support for running the entire loss function in eager mode. This can be enabled with ``use_eager: True``, e.g., ``rllib train --env=CartPole-v0 --run=PPO --config='{"use_eager": true, "simple_optimizer": true}'``. However this currently only works for a couple algorithms. 
+ Episode Traces ~~~~~~~~~~~~~~ diff --git a/doc/source/rllib.rst b/doc/source/rllib.rst index f0571b23c20e..d0d9d715aa7c 100644 --- a/doc/source/rllib.rst +++ b/doc/source/rllib.rst @@ -101,6 +101,8 @@ Concepts and Building Custom Algorithms - `Building Policies in TensorFlow `__ + - `Building Policies in TensorFlow Eager `__ + - `Building Policies in PyTorch `__ - `Extending Existing Policies `__ diff --git a/doc/source/tune-contrib.rst b/doc/source/tune-contrib.rst index f945ee679353..4774791e333d 100644 --- a/doc/source/tune-contrib.rst +++ b/doc/source/tune-contrib.rst @@ -15,10 +15,9 @@ We welcome (and encourage!) all forms of contributions to Tune, including and no Setting up a development environment ------------------------------------ -If you have Ray installed via pip (``pip install -U ray``), you can develop Tune locally without needing to compile Ray. +If you have Ray installed via pip (``pip install -U [link to wheel]`` - you can find the link to the latest wheel `here `__), you can develop Tune locally without needing to compile Ray. - -First, you will need your own [fork](https://help.github.com/en/articles/fork-a-repo) to work on the code. Press the Fork button on the `ray project page `__. +First, you will need your own `fork `__ to work on the code. Press the Fork button on the `ray project page `__. Then, clone the project to your machine and connect your repository to the upstream (main project) ray repository. .. code-block:: shell @@ -28,10 +27,16 @@ Then, clone the project to your machine and connect your repository to the upstr git remote add upstream https://github.com/ray-project/ray.git +Before continuing, make sure that your git branch is in sync with the installed Ray binaries (i.e., you are up-to-date on `master `__ and have the latest `wheel `__ installed.) + Then, run `[path to ray directory]/python/ray/setup-dev.py` `(also here on Github) `__ script. 
This sets up links between the ``tune`` dir (among other directories) in your local repo and the one bundled with the ``ray`` package. -When using this script, make sure that your git branch is in sync with the installed Ray binaries (i.e., you are up-to-date on `master `__ and have the latest `wheel `__ installed.) +As a last step make sure to install all packages required for development of tune. This can be done by running: + +.. code-block:: shell + + pip install -r [path to ray directory]/python/ray/tune/requirements-dev.txt What can I work on? @@ -89,7 +94,7 @@ burden and speedup review process. Documentation should be documented in `Google style `__ format. We also have tests for code formatting and linting that need to pass before merge. -Install `yapf==0.23, flake8, flake8-quotes`. You can run the following locally: +Install `yapf==0.23, flake8, flake8-quotes` (these are also in the `requirements-dev.txt` found in ``python/ray/tune``). You can run the following locally: .. code-block:: shell diff --git a/docker/stress_test/Dockerfile b/docker/stress_test/Dockerfile index 664370eb0479..1d174ed72f92 100644 --- a/docker/stress_test/Dockerfile +++ b/docker/stress_test/Dockerfile @@ -4,7 +4,7 @@ FROM ray-project/base-deps # We install ray and boto3 to enable the ray autoscaler as # a test runner. -RUN pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-cp27-cp27mu-manylinux1_x86_64.whl boto3 +RUN pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev1-cp27-cp27mu-manylinux1_x86_64.whl boto3 RUN mkdir -p /root/.ssh/ # We port the source code in so that we run the most up-to-date stress tests. diff --git a/docker/tune_test/Dockerfile b/docker/tune_test/Dockerfile index b0cf426c1b1d..6e098d5218f6 100644 --- a/docker/tune_test/Dockerfile +++ b/docker/tune_test/Dockerfile @@ -4,7 +4,7 @@ FROM ray-project/base-deps # We install ray and boto3 to enable the ray autoscaler as # a test runner. 
-RUN pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-cp27-cp27mu-manylinux1_x86_64.whl boto3 +RUN pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev1-cp27-cp27mu-manylinux1_x86_64.whl boto3 # We install this after the latest wheels -- this should not override the latest wheels. RUN apt-get install -y zlib1g-dev RUN pip install gym[atari]==0.10.11 opencv-python-headless tensorflow lz4 keras pytest-timeout smart_open diff --git a/java/BUILD.bazel b/java/BUILD.bazel index f86df8d40f96..80ccabccfc12 100644 --- a/java/BUILD.bazel +++ b/java/BUILD.bazel @@ -94,11 +94,14 @@ define_java_module( ":org_ray_ray_api", ":org_ray_ray_runtime", "@plasma//:org_apache_arrow_arrow_plasma", + "@maven//:com_google_guava_guava", + "@maven//:com_sun_xml_bind_jaxb_core", + "@maven//:com_sun_xml_bind_jaxb_impl", + "@maven//:commons_io_commons_io", + "@maven//:javax_xml_bind_jaxb_api", "@maven//:org_apache_commons_commons_lang3", "@maven//:org_slf4j_slf4j_api", "@maven//:org_testng_testng", - "@maven//:com_google_guava_guava", - "@maven//:commons_io_commons_io", ], ) @@ -160,7 +163,7 @@ flatbuffers_generated_files = [ "ErrorTableData.java", "ErrorType.java", "FunctionTableData.java", - "GcsTableEntry.java", + "GcsEntry.java", "HeartbeatBatchTableData.java", "HeartbeatTableData.java", "Language.java", diff --git a/java/api/src/main/java/org/ray/api/id/BaseId.java b/java/api/src/main/java/org/ray/api/id/BaseId.java index 3c5e1e3a3619..e08955d5a93e 100644 --- a/java/api/src/main/java/org/ray/api/id/BaseId.java +++ b/java/api/src/main/java/org/ray/api/id/BaseId.java @@ -41,13 +41,14 @@ public ByteBuffer toByteBuffer() { */ public boolean isNil() { if (isNilCache == null) { - isNilCache = true; + boolean localIsNil = true; for (int i = 0; i < size(); ++i) { if (id[i] != (byte) 0xff) { - isNilCache = false; + localIsNil = false; break; } } + isNilCache = localIsNil; } return isNilCache; } diff --git 
a/java/api/src/main/java/org/ray/api/options/ActorCreationOptions.java b/java/api/src/main/java/org/ray/api/options/ActorCreationOptions.java index e4f54f0094c4..d1e92f7bb9e9 100644 --- a/java/api/src/main/java/org/ray/api/options/ActorCreationOptions.java +++ b/java/api/src/main/java/org/ray/api/options/ActorCreationOptions.java @@ -1,5 +1,6 @@ package org.ray.api.options; +import java.util.HashMap; import java.util.Map; /** @@ -12,19 +13,32 @@ public class ActorCreationOptions extends BaseTaskOptions { public final int maxReconstructions; - public ActorCreationOptions() { - super(); - this.maxReconstructions = NO_RECONSTRUCTION; - } - - public ActorCreationOptions(Map resources) { + private ActorCreationOptions(Map resources, int maxReconstructions) { super(resources); - this.maxReconstructions = NO_RECONSTRUCTION; + this.maxReconstructions = maxReconstructions; } + /** + * The inner class for building ActorCreationOptions. + */ + public static class Builder { - public ActorCreationOptions(Map resources, int maxReconstructions) { - super(resources); - this.maxReconstructions = maxReconstructions; + private Map resources = new HashMap<>(); + private int maxReconstructions = NO_RECONSTRUCTION; + + public Builder setResources(Map resources) { + this.resources = resources; + return this; + } + + public Builder setMaxReconstructions(int maxReconstructions) { + this.maxReconstructions = maxReconstructions; + return this; + } + + public ActorCreationOptions createActorCreationOptions() { + return new ActorCreationOptions(resources, maxReconstructions); + } } + } diff --git a/java/api/src/main/java/org/ray/api/options/CallOptions.java b/java/api/src/main/java/org/ray/api/options/CallOptions.java index 84adfc122e04..1e5b61bf16d3 100644 --- a/java/api/src/main/java/org/ray/api/options/CallOptions.java +++ b/java/api/src/main/java/org/ray/api/options/CallOptions.java @@ -1,5 +1,6 @@ package org.ray.api.options; +import java.util.HashMap; import java.util.Map; /** @@ -7,12 
+8,24 @@ */ public class CallOptions extends BaseTaskOptions { - public CallOptions() { - super(); - } - - public CallOptions(Map resources) { + private CallOptions(Map resources) { super(resources); } + /** + * This inner class for building CallOptions. + */ + public static class Builder { + + private Map resources = new HashMap<>(); + + public Builder setResources(Map resources) { + this.resources = resources; + return this; + } + + public CallOptions createCallOptions() { + return new CallOptions(resources); + } + } } diff --git a/java/dependencies.bzl b/java/dependencies.bzl index d0178ba0f8f4..7c716166d399 100644 --- a/java/dependencies.bzl +++ b/java/dependencies.bzl @@ -11,7 +11,7 @@ def gen_java_deps(): "com.sun.xml.bind:jaxb-impl:2.3.0", "com.typesafe:config:1.3.2", "commons-io:commons-io:2.5", - "de.ruedigermoeller:fst:2.47", + "de.ruedigermoeller:fst:2.57", "javax.xml.bind:jaxb-api:2.3.0", "org.apache.commons:commons-lang3:3.4", "org.ow2.asm:asm:6.0", diff --git a/java/runtime/pom.xml b/java/runtime/pom.xml index 1ce51971c03e..c75e2eeef13f 100644 --- a/java/runtime/pom.xml +++ b/java/runtime/pom.xml @@ -54,7 +54,7 @@ de.ruedigermoeller fst - 2.47 + 2.57 org.apache.commons diff --git a/java/runtime/src/main/java/org/ray/runtime/RayPyActorImpl.java b/java/runtime/src/main/java/org/ray/runtime/RayPyActorImpl.java index 2938478d22e8..f1f26d40874e 100644 --- a/java/runtime/src/main/java/org/ray/runtime/RayPyActorImpl.java +++ b/java/runtime/src/main/java/org/ray/runtime/RayPyActorImpl.java @@ -20,7 +20,9 @@ public class RayPyActorImpl extends RayActorImpl implements RayPyActor { */ private String className; - private RayPyActorImpl() {} + // Note that this empty constructor must be public + // since it'll be needed when deserializing. 
+ public RayPyActorImpl() {} public RayPyActorImpl(UniqueId id, String moduleName, String className) { super(id); diff --git a/java/test/pom.xml b/java/test/pom.xml index 10f7ea4b3313..6a3a31d2032e 100644 --- a/java/test/pom.xml +++ b/java/test/pom.xml @@ -32,11 +32,26 @@ guava 27.0.1-jre + + com.sun.xml.bind + jaxb-core + 2.3.0 + + + com.sun.xml.bind + jaxb-impl + 2.3.0 + commons-io commons-io 2.5 + + javax.xml.bind + jaxb-api + 2.3.0 + org.apache.commons commons-lang3 diff --git a/java/test/src/main/java/org/ray/api/test/ActorReconstructionTest.java b/java/test/src/main/java/org/ray/api/test/ActorReconstructionTest.java index e575daa84f13..149c87f55931 100644 --- a/java/test/src/main/java/org/ray/api/test/ActorReconstructionTest.java +++ b/java/test/src/main/java/org/ray/api/test/ActorReconstructionTest.java @@ -3,7 +3,6 @@ import static org.ray.runtime.util.SystemUtil.pid; import java.io.IOException; -import java.util.HashMap; import java.util.List; import java.util.concurrent.TimeUnit; import org.ray.api.Checkpointable; @@ -47,7 +46,8 @@ public int getPid() { @Test public void testActorReconstruction() throws InterruptedException, IOException { TestUtils.skipTestUnderSingleProcess(); - ActorCreationOptions options = new ActorCreationOptions(new HashMap<>(), 1); + ActorCreationOptions options = + new ActorCreationOptions.Builder().setMaxReconstructions(1).createActorCreationOptions(); RayActor actor = Ray.createActor(Counter::new, options); // Call increase 3 times. 
for (int i = 0; i < 3; i++) { @@ -127,8 +127,8 @@ public void checkpointExpired(UniqueId actorId, UniqueId checkpointId) { @Test public void testActorCheckpointing() throws IOException, InterruptedException { TestUtils.skipTestUnderSingleProcess(); - - ActorCreationOptions options = new ActorCreationOptions(new HashMap<>(), 1); + ActorCreationOptions options = + new ActorCreationOptions.Builder().setMaxReconstructions(1).createActorCreationOptions(); RayActor actor = Ray.createActor(CheckpointableCounter::new, options); // Call increase 3 times. for (int i = 0; i < 3; i++) { diff --git a/java/test/src/main/java/org/ray/api/test/DynamicResourceTest.java b/java/test/src/main/java/org/ray/api/test/DynamicResourceTest.java index ffda0732287e..79b3eba0ed13 100644 --- a/java/test/src/main/java/org/ray/api/test/DynamicResourceTest.java +++ b/java/test/src/main/java/org/ray/api/test/DynamicResourceTest.java @@ -23,7 +23,8 @@ public static String sayHi() { @Test public void testSetResource() { TestUtils.skipTestUnderSingleProcess(); - CallOptions op1 = new CallOptions(ImmutableMap.of("A", 10.0)); + CallOptions op1 = + new CallOptions.Builder().setResources(ImmutableMap.of("A", 10.0)).createCallOptions(); RayObject obj = Ray.call(DynamicResourceTest::sayHi, op1); WaitResult result = Ray.wait(ImmutableList.of(obj), 1, 1000); Assert.assertEquals(result.getReady().size(), 0); diff --git a/java/test/src/main/java/org/ray/api/test/HelloWorldTest.java b/java/test/src/main/java/org/ray/api/test/HelloWorldTest.java index feb07fe2cd42..04883bdf8673 100644 --- a/java/test/src/main/java/org/ray/api/test/HelloWorldTest.java +++ b/java/test/src/main/java/org/ray/api/test/HelloWorldTest.java @@ -33,4 +33,5 @@ public void testHelloWorld() { String helloWorld = Ray.call(HelloWorldTest::merge, hello, world).get(); Assert.assertEquals("hello,world!", helloWorld); } + } diff --git a/java/test/src/main/java/org/ray/api/test/RaySerializerTest.java 
b/java/test/src/main/java/org/ray/api/test/RaySerializerTest.java new file mode 100644 index 000000000000..33283abc7a36 --- /dev/null +++ b/java/test/src/main/java/org/ray/api/test/RaySerializerTest.java @@ -0,0 +1,23 @@ +package org.ray.api.test; + +import org.ray.api.RayPyActor; +import org.ray.api.id.UniqueId; +import org.ray.runtime.RayPyActorImpl; +import org.ray.runtime.util.Serializer; +import org.testng.Assert; +import org.testng.annotations.Test; + +public class RaySerializerTest { + + @Test + public void testSerializePyActor() { + final UniqueId pyActorId = UniqueId.randomId(); + RayPyActor pyActor = new RayPyActorImpl(pyActorId, "test", "RaySerializerTest"); + byte[] bytes = Serializer.encode(pyActor); + RayPyActor result = Serializer.decode(bytes); + Assert.assertEquals(result.getId(), pyActorId); + Assert.assertEquals(result.getModuleName(), "test"); + Assert.assertEquals(result.getClassName(), "RaySerializerTest"); + } + +} diff --git a/java/test/src/main/java/org/ray/api/test/ResourcesManagementTest.java b/java/test/src/main/java/org/ray/api/test/ResourcesManagementTest.java index c3d0e4152e5a..dca559764b87 100644 --- a/java/test/src/main/java/org/ray/api/test/ResourcesManagementTest.java +++ b/java/test/src/main/java/org/ray/api/test/ResourcesManagementTest.java @@ -46,14 +46,16 @@ public Integer echo(Integer number) { @Test public void testMethods() { TestUtils.skipTestUnderSingleProcess(); - CallOptions callOptions1 = new CallOptions(ImmutableMap.of("CPU", 4.0)); + CallOptions callOptions1 = + new CallOptions.Builder().setResources(ImmutableMap.of("CPU", 4.0)).createCallOptions(); // This is a case that can satisfy required resources. // The static resources for test are "CPU:4,RES-A:4". 
RayObject result1 = Ray.call(ResourcesManagementTest::echo, 100, callOptions1); Assert.assertEquals(100, (int) result1.get()); - CallOptions callOptions2 = new CallOptions(ImmutableMap.of("CPU", 4.0)); + CallOptions callOptions2 = + new CallOptions.Builder().setResources(ImmutableMap.of("CPU", 4.0)).createCallOptions(); // This is a case that can't satisfy required resources. // The static resources for test are "CPU:4,RES-A:4". @@ -64,7 +66,8 @@ public void testMethods() { Assert.assertEquals(0, waitResult.getUnready().size()); try { - CallOptions callOptions3 = new CallOptions(ImmutableMap.of("CPU", 0.0)); + CallOptions callOptions3 = + new CallOptions.Builder().setResources(ImmutableMap.of("CPU", 0.0)).createCallOptions(); Assert.fail(); } catch (RuntimeException e) { // We should receive a RuntimeException indicates that we should not @@ -76,9 +79,8 @@ public void testMethods() { public void testActors() { TestUtils.skipTestUnderSingleProcess(); - ActorCreationOptions actorCreationOptions1 = - new ActorCreationOptions(ImmutableMap.of("CPU", 2.0)); - + ActorCreationOptions actorCreationOptions1 = new ActorCreationOptions.Builder() + .setResources(ImmutableMap.of("CPU", 2.0)).createActorCreationOptions(); // This is a case that can satisfy required resources. // The static resources for test are "CPU:4,RES-A:4". RayActor echo1 = Ray.createActor(Echo::new, actorCreationOptions1); @@ -87,8 +89,8 @@ public void testActors() { // This is a case that can't satisfy required resources. // The static resources for test are "CPU:4,RES-A:4". 
- ActorCreationOptions actorCreationOptions2 = - new ActorCreationOptions(ImmutableMap.of("CPU", 8.0)); + ActorCreationOptions actorCreationOptions2 = new ActorCreationOptions.Builder() + .setResources(ImmutableMap.of("CPU", 8.0)).createActorCreationOptions(); RayActor echo2 = Ray.createActor(Echo::new, actorCreationOptions2); diff --git a/python/ray/__init__.py b/python/ray/__init__.py index 421b1c6838ac..03792e5eb48a 100644 --- a/python/ray/__init__.py +++ b/python/ray/__init__.py @@ -96,7 +96,7 @@ from ray.runtime_context import _get_runtime_context # noqa: E402 # Ray version string. -__version__ = "0.7.1" +__version__ = "0.8.0.dev1" __all__ = [ "global_state", diff --git a/python/ray/autoscaler/aws/example-full.yaml b/python/ray/autoscaler/aws/example-full.yaml index 7399450aeedb..b3ecd22e7d5d 100644 --- a/python/ray/autoscaler/aws/example-full.yaml +++ b/python/ray/autoscaler/aws/example-full.yaml @@ -113,9 +113,9 @@ setup_commands: # has your Ray repo pre-cloned. Then, you can replace the pip installs # below with a git checkout (and possibly a recompile). 
- echo 'export PATH="$HOME/anaconda3/envs/tensorflow_p36/bin:$PATH"' >> ~/.bashrc - # - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-cp27-cp27mu-manylinux1_x86_64.whl - # - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-cp35-cp35m-manylinux1_x86_64.whl - - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-cp36-cp36m-manylinux1_x86_64.whl + # - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev1-cp27-cp27mu-manylinux1_x86_64.whl + # - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev1-cp35-cp35m-manylinux1_x86_64.whl + - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev1-cp36-cp36m-manylinux1_x86_64.whl # Consider uncommenting these if you also want to run apt-get commands during setup # - sudo pkill -9 apt-get || true # - sudo pkill -9 dpkg || true diff --git a/python/ray/autoscaler/aws/example-gpu-docker.yaml b/python/ray/autoscaler/aws/example-gpu-docker.yaml index 79fdc055b091..b63030a48344 100644 --- a/python/ray/autoscaler/aws/example-gpu-docker.yaml +++ b/python/ray/autoscaler/aws/example-gpu-docker.yaml @@ -105,9 +105,9 @@ file_mounts: { # List of shell commands to run to set up nodes. 
setup_commands: - # - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-cp27-cp27mu-manylinux1_x86_64.whl - - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-cp35-cp35m-manylinux1_x86_64.whl - # - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-cp36-cp36m-manylinux1_x86_64.whl + # - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev1-cp27-cp27mu-manylinux1_x86_64.whl + - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev1-cp35-cp35m-manylinux1_x86_64.whl + # - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev1-cp36-cp36m-manylinux1_x86_64.whl # Custom commands that will be run on the head node after common setup. head_setup_commands: diff --git a/python/ray/autoscaler/gcp/example-full.yaml b/python/ray/autoscaler/gcp/example-full.yaml index 4ab2093dd865..c307f1b10103 100644 --- a/python/ray/autoscaler/gcp/example-full.yaml +++ b/python/ray/autoscaler/gcp/example-full.yaml @@ -127,9 +127,9 @@ setup_commands: && echo 'export PATH="$HOME/anaconda3/bin:$PATH"' >> ~/.profile # Install ray - # - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-cp27-cp27mu-manylinux1_x86_64.whl - # - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-cp35-cp35m-manylinux1_x86_64.whl - - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-cp36-cp36m-manylinux1_x86_64.whl + # - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev1-cp27-cp27mu-manylinux1_x86_64.whl + # - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev1-cp35-cp35m-manylinux1_x86_64.whl + - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev1-cp36-cp36m-manylinux1_x86_64.whl # Custom commands that will be run on the head node after common setup. 
diff --git a/python/ray/autoscaler/gcp/example-gpu-docker.yaml b/python/ray/autoscaler/gcp/example-gpu-docker.yaml index 75e0497094cb..5bb5eb9fb980 100644 --- a/python/ray/autoscaler/gcp/example-gpu-docker.yaml +++ b/python/ray/autoscaler/gcp/example-gpu-docker.yaml @@ -140,9 +140,9 @@ setup_commands: # - echo 'export PATH="$HOME/anaconda3/envs/tensorflow_p36/bin:$PATH"' >> ~/.bashrc # Install ray - # - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-cp27-cp27mu-manylinux1_x86_64.whl - - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-cp35-cp35m-manylinux1_x86_64.whl - # - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-cp36-cp36m-manylinux1_x86_64.whl + # - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev1-cp27-cp27mu-manylinux1_x86_64.whl + - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev1-cp35-cp35m-manylinux1_x86_64.whl + # - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev1-cp36-cp36m-manylinux1_x86_64.whl # Custom commands that will be run on the head node after common setup. head_setup_commands: diff --git a/python/ray/autoscaler/updater.py b/python/ray/autoscaler/updater.py index c86750fe399d..d42bf041ac8c 100644 --- a/python/ray/autoscaler/updater.py +++ b/python/ray/autoscaler/updater.py @@ -165,9 +165,10 @@ def wait_for_ssh(self, deadline): logger.debug("NodeUpdater: " "{}: Waiting for SSH...".format(self.node_id)) - with open("/dev/null", "w") as redirect: - self.ssh_cmd( - "uptime", connect_timeout=5, redirect=redirect) + # Setting redirect=False allows the user to see errors like + # unix_listener: path "/tmp/rkn_ray_ssh_sockets/..." too long + # for Unix domain socket. 
+ self.ssh_cmd("uptime", connect_timeout=5, redirect=False) return True diff --git a/python/ray/experimental/sgd/pytorch/distributed_pytorch_runner.py b/python/ray/experimental/sgd/pytorch/distributed_pytorch_runner.py new file mode 100644 index 000000000000..160544633353 --- /dev/null +++ b/python/ray/experimental/sgd/pytorch/distributed_pytorch_runner.py @@ -0,0 +1,131 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import logging +import os +import torch.distributed as dist +import torch.utils.data + +from ray.experimental.sgd.pytorch.pytorch_runner import PyTorchRunner + +logger = logging.getLogger(__name__) + + +class DistributedPyTorchRunner(PyTorchRunner): + """Manages a distributed PyTorch model replica.""" + + def __init__(self, + model_creator, + data_creator, + optimizer_creator, + config=None, + batch_size=16, + backend="gloo"): + """Initializes the runner. + + Args: + model_creator (dict -> torch.nn.Module): see pytorch_trainer.py. + data_creator (dict -> Dataset, Dataset): see pytorch_trainer.py. + optimizer_creator (torch.nn.Module, dict -> loss, optimizer): + see pytorch_trainer.py. + config (dict): see pytorch_trainer.py. + batch_size (int): batch size used by one replica for an update. + backend (string): see pytorch_trainer.py. + """ + super(DistributedPyTorchRunner, self).__init__( + model_creator, data_creator, optimizer_creator, config, batch_size) + self.backend = backend + + def setup(self, url, world_rank, world_size): + """Connects to the distributed PyTorch backend and initializes the model. + + Args: + url (str): the URL used to connect to distributed PyTorch. + world_rank (int): the index of the runner. + world_size (int): the total number of runners. 
+ """ + self._setup_distributed_pytorch(url, world_rank, world_size) + self._setup_training() + + def _setup_distributed_pytorch(self, url, world_rank, world_size): + os.environ["CUDA_LAUNCH_BLOCKING"] = "1" + with self._timers["setup_proc"]: + self.world_rank = world_rank + logger.debug( + "Connecting to {} world_rank: {} world_size: {}".format( + url, world_rank, world_size)) + logger.debug("using {}".format(self.backend)) + dist.init_process_group( + backend=self.backend, + init_method=url, + rank=world_rank, + world_size=world_size) + + def _setup_training(self): + logger.debug("Creating model") + self.model = self.model_creator(self.config) + if torch.cuda.is_available(): + self.model = torch.nn.parallel.DistributedDataParallel( + self.model.cuda()) + else: + self.model = torch.nn.parallel.DistributedDataParallelCPU( + self.model) + + logger.debug("Creating optimizer") + self.criterion, self.optimizer = self.optimizer_creator( + self.model, self.config) + if torch.cuda.is_available(): + self.criterion = self.criterion.cuda() + + logger.debug("Creating dataset") + self.training_set, self.validation_set = self.data_creator(self.config) + + # TODO: make num_workers configurable + self.train_sampler = torch.utils.data.distributed.DistributedSampler( + self.training_set) + self.train_loader = torch.utils.data.DataLoader( + self.training_set, + batch_size=self.batch_size, + shuffle=(self.train_sampler is None), + num_workers=2, + pin_memory=False, + sampler=self.train_sampler) + + self.validation_sampler = ( + torch.utils.data.distributed.DistributedSampler( + self.validation_set)) + self.validation_loader = torch.utils.data.DataLoader( + self.validation_set, + batch_size=self.batch_size, + shuffle=(self.validation_sampler is None), + num_workers=2, + pin_memory=False, + sampler=self.validation_sampler) + + def step(self): + """Runs a training epoch and updates the model parameters.""" + logger.debug("Starting step") + self.train_sampler.set_epoch(self.epoch) + 
return super(DistributedPyTorchRunner, self).step() + + def get_state(self): + """Returns the state of the runner.""" + return { + "epoch": self.epoch, + "model": self.model.module.state_dict(), + "optimizer": self.optimizer.state_dict(), + "stats": self.stats() + } + + def set_state(self, state): + """Sets the state of the model.""" + # TODO: restore timer stats + self.model.module.load_state_dict(state["model"]) + self.optimizer.load_state_dict(state["optimizer"]) + self.epoch = state["stats"]["epoch"] + + def shutdown(self): + """Attempts to shut down the worker.""" + super(DistributedPyTorchRunner, self).shutdown() + dist.destroy_process_group() diff --git a/python/ray/experimental/sgd/pytorch/pytorch_runner.py b/python/ray/experimental/sgd/pytorch/pytorch_runner.py index 5fe4ba1009f9..1663b2c64f0e 100644 --- a/python/ray/experimental/sgd/pytorch/pytorch_runner.py +++ b/python/ray/experimental/sgd/pytorch/pytorch_runner.py @@ -3,9 +3,7 @@ from __future__ import print_function import logging -import os import torch -import torch.distributed as dist import torch.utils.data import ray @@ -15,28 +13,23 @@ class PyTorchRunner(object): - """Manages a distributed PyTorch model replica""" + """Manages a PyTorch model for training.""" def __init__(self, model_creator, data_creator, optimizer_creator, config=None, - batch_size=16, - backend="gloo"): + batch_size=16): """Initializes the runner. Args: - model_creator (dict -> torch.nn.Module): creates the model using - the config. - data_creator (dict -> Dataset, Dataset): creates the training and - validation data sets using the config. + model_creator (dict -> torch.nn.Module): see pytorch_trainer.py. + data_creator (dict -> Dataset, Dataset): see pytorch_trainer.py. optimizer_creator (torch.nn.Module, dict -> loss, optimizer): - creates the loss and optimizer using the model and the config. - config (dict): configuration passed to 'model_creator', - 'data_creator', and 'optimizer_creator'. 
- batch_size (int): batch size used in an update. - backend (string): backend used by distributed PyTorch. + see pytorch_trainer.py. + config (dict): see pytorch_trainer.py. + batch_size (int): see pytorch_trainer.py. """ self.model_creator = model_creator @@ -44,7 +37,6 @@ def __init__(self, self.optimizer_creator = optimizer_creator self.config = {} if config is None else config self.batch_size = batch_size - self.backend = backend self.verbose = True self.epoch = 0 @@ -56,82 +48,45 @@ def __init__(self, ] } - def setup(self, url, world_rank, world_size): - """Connects to the distributed PyTorch backend and initializes the model. - - Args: - url (str): the URL used to connect to distributed PyTorch. - world_rank (int): the index of the runner. - world_size (int): the total number of runners. - """ - self._setup_distributed_pytorch(url, world_rank, world_size) - self._setup_training() - - def _setup_distributed_pytorch(self, url, world_rank, world_size): - os.environ["CUDA_LAUNCH_BLOCKING"] = "1" - with self._timers["setup_proc"]: - self.world_rank = world_rank - logger.debug( - "Connecting to {} world_rank: {} world_size: {}".format( - url, world_rank, world_size)) - logger.debug("using {}".format(self.backend)) - dist.init_process_group( - backend=self.backend, - init_method=url, - rank=world_rank, - world_size=world_size) - - def _setup_training(self): + def setup(self): + """Initializes the model.""" logger.debug("Creating model") self.model = self.model_creator(self.config) if torch.cuda.is_available(): - self.model = torch.nn.parallel.DistributedDataParallel( - self.model.cuda()) - else: - self.model = torch.nn.parallel.DistributedDataParallelCPU( - self.model) + self.model = self.model.cuda() logger.debug("Creating optimizer") self.criterion, self.optimizer = self.optimizer_creator( self.model, self.config) - if torch.cuda.is_available(): self.criterion = self.criterion.cuda() logger.debug("Creating dataset") self.training_set, self.validation_set = 
self.data_creator(self.config) - - # TODO: make num_workers configurable - self.train_sampler = torch.utils.data.distributed.DistributedSampler( - self.training_set) self.train_loader = torch.utils.data.DataLoader( self.training_set, batch_size=self.batch_size, - shuffle=(self.train_sampler is None), + shuffle=True, num_workers=2, - pin_memory=False, - sampler=self.train_sampler) + pin_memory=False) - self.validation_sampler = ( - torch.utils.data.distributed.DistributedSampler( - self.validation_set)) self.validation_loader = torch.utils.data.DataLoader( self.validation_set, batch_size=self.batch_size, - shuffle=(self.validation_sampler is None), + shuffle=True, num_workers=2, - pin_memory=False, - sampler=self.validation_sampler) + pin_memory=False) def get_node_ip(self): - """Returns the IP address of the current node""" + """Returns the IP address of the current node.""" return ray.services.get_node_ip_address() - def step(self): - """Runs a training epoch and updates the model parameters""" - logger.debug("Starting step") - self.train_sampler.set_epoch(self.epoch) + def find_free_port(self): + """Finds a free port on the current node.""" + return utils.find_free_port() + def step(self): + """Runs a training epoch and updates the model parameters.""" logger.debug("Begin Training Epoch {}".format(self.epoch + 1)) with self._timers["training"]: train_stats = utils.train(self.train_loader, self.model, @@ -144,7 +99,7 @@ def step(self): return train_stats def validate(self): - """Evaluates the model on the validation data set""" + """Evaluates the model on the validation data set.""" with self._timers["validation"]: validation_stats = utils.validate(self.validation_loader, self.model, self.criterion) @@ -153,7 +108,7 @@ def validate(self): return validation_stats def stats(self): - """Returns a dictionary of statistics collected""" + """Returns a dictionary of statistics collected.""" stats = {"epoch": self.epoch} for k, t in self._timers.items(): stats[k + 
"_time_mean"] = t.mean @@ -162,7 +117,7 @@ def stats(self): return stats def get_state(self): - """Returns the state of the runner""" + """Returns the state of the runner.""" return { "epoch": self.epoch, "model": self.model.state_dict(), @@ -171,12 +126,20 @@ def get_state(self): } def set_state(self, state): - """Sets the state of the model""" + """Sets the state of the model.""" # TODO: restore timer stats self.model.load_state_dict(state["model"]) self.optimizer.load_state_dict(state["optimizer"]) self.epoch = state["stats"]["epoch"] def shutdown(self): - """Attempts to shut down the worker""" - dist.destroy_process_group() + """Attempts to shut down the worker.""" + del self.validation_loader + del self.validation_set + del self.train_loader + del self.training_set + del self.criterion + del self.optimizer + del self.model + if torch.cuda.is_available(): + torch.cuda.empty_cache() diff --git a/python/ray/experimental/sgd/pytorch/pytorch_trainer.py b/python/ray/experimental/sgd/pytorch/pytorch_trainer.py index 073ad3d34042..0e0c5d8436a1 100644 --- a/python/ray/experimental/sgd/pytorch/pytorch_trainer.py +++ b/python/ray/experimental/sgd/pytorch/pytorch_trainer.py @@ -3,13 +3,15 @@ from __future__ import print_function import numpy as np -import sys import torch +import torch.distributed as dist import logging import ray from ray.experimental.sgd.pytorch.pytorch_runner import PyTorchRunner +from ray.experimental.sgd.pytorch.distributed_pytorch_runner import ( + DistributedPyTorchRunner) from ray.experimental.sgd.pytorch import utils logger = logging.getLogger(__name__) @@ -51,10 +53,11 @@ def __init__(self, """ # TODO: add support for mixed precision # TODO: add support for callbacks - if sys.platform == "darwin": - raise Exception( - ("Distributed PyTorch is not supported on macOS. For more " - "information, see " + if num_replicas > 1 and not dist.is_available(): + raise ValueError( + ("Distributed PyTorch is not supported on macOS. 
" + "To run without distributed PyTorch, set 'num_replicas=1'. " + "For more information, see " "https://github.com/pytorch/examples/issues/467.")) self.model_creator = model_creator @@ -68,40 +71,55 @@ def __init__(self, if backend == "auto": backend = "nccl" if resources_per_replica.num_gpus > 0 else "gloo" - Runner = ray.remote( - num_cpus=resources_per_replica.num_cpus, - num_gpus=resources_per_replica.num_gpus, - resources=resources_per_replica.resources)(PyTorchRunner) - - batch_size_per_replica = batch_size // num_replicas - if batch_size % num_replicas > 0: - new_batch_size = batch_size_per_replica * num_replicas - logger.warn( - ("Changing batch size from {old_batch_size} to " - "{new_batch_size} to evenly distribute batches across " - "{num_replicas} replicas.").format( - old_batch_size=batch_size, - new_batch_size=new_batch_size, - num_replicas=num_replicas)) - - self.workers = [ - Runner.remote(model_creator, data_creator, optimizer_creator, - self.config, batch_size_per_replica, backend) - for i in range(num_replicas) - ] - - ip = ray.get(self.workers[0].get_node_ip.remote()) - port = utils.find_free_port() - address = "tcp://{ip}:{port}".format(ip=ip, port=port) - - # Get setup tasks in order to throw errors on failure - ray.get([ - worker.setup.remote(address, i, len(self.workers)) - for i, worker in enumerate(self.workers) - ]) + if num_replicas == 1: + # Generate actor class + Runner = ray.remote( + num_cpus=resources_per_replica.num_cpus, + num_gpus=resources_per_replica.num_gpus, + resources=resources_per_replica.resources)(PyTorchRunner) + # Start workers + self.workers = [ + Runner.remote(model_creator, data_creator, optimizer_creator, + self.config, batch_size) + ] + # Get setup tasks in order to throw errors on failure + ray.get(self.workers[0].setup.remote()) + else: + # Geneate actor class + Runner = ray.remote( + num_cpus=resources_per_replica.num_cpus, + num_gpus=resources_per_replica.num_gpus, + 
resources=resources_per_replica.resources)( + DistributedPyTorchRunner) + # Compute batch size per replica + batch_size_per_replica = batch_size // num_replicas + if batch_size % num_replicas > 0: + new_batch_size = batch_size_per_replica * num_replicas + logger.warn( + ("Changing batch size from {old_batch_size} to " + "{new_batch_size} to evenly distribute batches across " + "{num_replicas} replicas.").format( + old_batch_size=batch_size, + new_batch_size=new_batch_size, + num_replicas=num_replicas)) + # Start workers + self.workers = [ + Runner.remote(model_creator, data_creator, optimizer_creator, + self.config, batch_size_per_replica, backend) + for i in range(num_replicas) + ] + # Compute URL for initializing distributed PyTorch + ip = ray.get(self.workers[0].get_node_ip.remote()) + port = ray.get(self.workers[0].find_free_port.remote()) + address = "tcp://{ip}:{port}".format(ip=ip, port=port) + # Get setup tasks in order to throw errors on failure + ray.get([ + worker.setup.remote(address, i, len(self.workers)) + for i, worker in enumerate(self.workers) + ]) def train(self): - """Runs a training epoch""" + """Runs a training epoch.""" with self.optimizer_timer: worker_stats = ray.get([w.step.remote() for w in self.workers]) @@ -111,7 +129,7 @@ def train(self): return train_stats def validate(self): - """Evaluates the model on the validation data set""" + """Evaluates the model on the validation data set.""" worker_stats = ray.get([w.validate.remote() for w in self.workers]) validation_stats = worker_stats[0].copy() validation_stats["validation_loss"] = np.mean( @@ -119,32 +137,25 @@ def validate(self): return validation_stats def get_model(self): - """Returns the learned model""" + """Returns the learned model.""" model = self.model_creator(self.config) state = ray.get(self.workers[0].get_state.remote()) - - # Remove module. 
prefix added by distrbuted pytorch - state_dict = { - k.replace("module.", ""): v - for k, v in state["model"].items() - } - - model.load_state_dict(state_dict) + model.load_state_dict(state["model"]) return model def save(self, ckpt): - """Saves the model at the provided checkpoint""" + """Saves the model at the provided checkpoint.""" state = ray.get(self.workers[0].get_state.remote()) torch.save(state, ckpt) def restore(self, ckpt): - """Restores the model from the provided checkpoint""" + """Restores the model from the provided checkpoint.""" state = torch.load(ckpt) state_id = ray.put(state) ray.get([worker.set_state.remote(state_id) for worker in self.workers]) def shutdown(self): - """Shuts down workers and releases resources""" + """Shuts down workers and releases resources.""" for worker in self.workers: worker.shutdown.remote() worker.__ray_terminate__.remote() diff --git a/python/ray/experimental/sgd/pytorch/utils.py b/python/ray/experimental/sgd/pytorch/utils.py index f7c6e4abac97..5be26b331cfd 100644 --- a/python/ray/experimental/sgd/pytorch/utils.py +++ b/python/ray/experimental/sgd/pytorch/utils.py @@ -196,7 +196,7 @@ def find_free_port(): class AverageMeter(object): - """Computes and stores the average and current value""" + """Computes and stores the average and current value.""" def __init__(self): self.reset() diff --git a/python/ray/experimental/sgd/tests/test_pytorch.py b/python/ray/experimental/sgd/tests/test_pytorch.py index faff23f8a809..aa0596aa158c 100644 --- a/python/ray/experimental/sgd/tests/test_pytorch.py +++ b/python/ray/experimental/sgd/tests/test_pytorch.py @@ -4,9 +4,9 @@ import os import pytest -import sys import tempfile import torch +import torch.distributed as dist from ray.tests.conftest import ray_start_2_cpus # noqa: F401 from ray.experimental.sgd.pytorch import PyTorchTrainer, Resources @@ -15,14 +15,14 @@ model_creator, optimizer_creator, data_creator) -@pytest.mark.skipif( # noqa: F811 - sys.platform == "darwin", 
reason="Doesn't work on macOS.") -def test_train(ray_start_2_cpus): # noqa: F811 +@pytest.mark.parametrize( # noqa: F811 + "num_replicas", [1, 2] if dist.is_available() else [1]) +def test_train(ray_start_2_cpus, num_replicas): # noqa: F811 trainer = PyTorchTrainer( model_creator, data_creator, optimizer_creator, - num_replicas=2, + num_replicas=num_replicas, resources_per_replica=Resources(num_cpus=1)) train_loss1 = trainer.train()["train_loss"] validation_loss1 = trainer.validate()["validation_loss"] @@ -37,14 +37,14 @@ def test_train(ray_start_2_cpus): # noqa: F811 assert validation_loss2 <= validation_loss1 -@pytest.mark.skipif( # noqa: F811 - sys.platform == "darwin", reason="Doesn't work on macOS.") -def test_save_and_restore(ray_start_2_cpus): # noqa: F811 +@pytest.mark.parametrize( # noqa: F811 + "num_replicas", [1, 2] if dist.is_available() else [1]) +def test_save_and_restore(ray_start_2_cpus, num_replicas): # noqa: F811 trainer1 = PyTorchTrainer( model_creator, data_creator, optimizer_creator, - num_replicas=2, + num_replicas=num_replicas, resources_per_replica=Resources(num_cpus=1)) trainer1.train() @@ -59,7 +59,7 @@ def test_save_and_restore(ray_start_2_cpus): # noqa: F811 model_creator, data_creator, optimizer_creator, - num_replicas=2, + num_replicas=num_replicas, resources_per_replica=Resources(num_cpus=1)) trainer2.restore(filename) diff --git a/python/ray/gcs_utils.py b/python/ray/gcs_utils.py index 15eec6c81136..cadd197ec73f 100644 --- a/python/ray/gcs_utils.py +++ b/python/ray/gcs_utils.py @@ -9,7 +9,7 @@ from ray.core.generated.ClientTableData import ClientTableData from ray.core.generated.DriverTableData import DriverTableData from ray.core.generated.ErrorTableData import ErrorTableData -from ray.core.generated.GcsTableEntry import GcsTableEntry +from ray.core.generated.GcsEntry import GcsEntry from ray.core.generated.HeartbeatBatchTableData import HeartbeatBatchTableData from ray.core.generated.HeartbeatTableData import HeartbeatTableData 
from ray.core.generated.Language import Language @@ -25,7 +25,7 @@ "ClientTableData", "DriverTableData", "ErrorTableData", - "GcsTableEntry", + "GcsEntry", "HeartbeatBatchTableData", "HeartbeatTableData", "Language", diff --git a/python/ray/monitor.py b/python/ray/monitor.py index 09a154d7b548..c9e0424b3eb8 100644 --- a/python/ray/monitor.py +++ b/python/ray/monitor.py @@ -101,8 +101,7 @@ def subscribe(self, channel): def xray_heartbeat_batch_handler(self, unused_channel, data): """Handle an xray heartbeat batch message from Redis.""" - gcs_entries = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry( - data, 0) + gcs_entries = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(data, 0) heartbeat_data = gcs_entries.Entries(0) message = (ray.gcs_utils.HeartbeatBatchTableData. @@ -208,8 +207,7 @@ def xray_driver_removed_handler(self, unused_channel, data): unused_channel: The message channel. data: The message data. """ - gcs_entries = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry( - data, 0) + gcs_entries = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(data, 0) driver_data = gcs_entries.Entries(0) message = ray.gcs_utils.DriverTableData.GetRootAsDriverTableData( driver_data, 0) diff --git a/python/ray/remote_function.py b/python/ray/remote_function.py index 44d2777a2900..9ff6994b8d42 100644 --- a/python/ray/remote_function.py +++ b/python/ray/remote_function.py @@ -4,6 +4,7 @@ import copy import logging +from functools import wraps from ray.function_manager import FunctionDescriptor import ray.signature @@ -74,15 +75,18 @@ def __init__(self, function, num_cpus, num_gpus, resources, self._last_driver_id_exported_for = None + # Override task.remote's signature and docstring + @wraps(function) + def _remote_proxy(*args, **kwargs): + return self._remote(args=args, kwargs=kwargs) + + self.remote = _remote_proxy + def __call__(self, *args, **kwargs): raise Exception("Remote functions cannot be called directly. 
Instead " "of running '{}()', try '{}.remote()'.".format( self._function_name, self._function_name)) - def remote(self, *args, **kwargs): - """This runs immediately when a remote function is called.""" - return self._remote(args=args, kwargs=kwargs) - def _submit(self, args=None, kwargs=None, diff --git a/python/ray/rllib/agents/a3c/a3c_tf_policy.py b/python/ray/rllib/agents/a3c/a3c_tf_policy.py index ed3676472850..d05f496a7945 100644 --- a/python/ray/rllib/agents/a3c/a3c_tf_policy.py +++ b/python/ray/rllib/agents/a3c/a3c_tf_policy.py @@ -41,8 +41,9 @@ def actor_critic_loss(policy, batch_tensors): policy.loss = A3CLoss( policy.action_dist, batch_tensors[SampleBatch.ACTIONS], batch_tensors[Postprocessing.ADVANTAGES], - batch_tensors[Postprocessing.VALUE_TARGETS], policy.vf, - policy.config["vf_loss_coeff"], policy.config["entropy_coeff"]) + batch_tensors[Postprocessing.VALUE_TARGETS], + policy.convert_to_eager(policy.vf), policy.config["vf_loss_coeff"], + policy.config["entropy_coeff"]) return policy.loss.total_loss diff --git a/python/ray/rllib/agents/ddpg/apex.py b/python/ray/rllib/agents/ddpg/apex.py index 5ea732f17508..e0731e87a809 100644 --- a/python/ray/rllib/agents/ddpg/apex.py +++ b/python/ray/rllib/agents/ddpg/apex.py @@ -2,15 +2,14 @@ from __future__ import division from __future__ import print_function +from ray.rllib.agents.dqn.apex import APEX_TRAINER_PROPERTIES from ray.rllib.agents.ddpg.ddpg import DDPGTrainer, \ DEFAULT_CONFIG as DDPG_CONFIG -from ray.rllib.utils.annotations import override from ray.rllib.utils import merge_dicts APEX_DDPG_DEFAULT_CONFIG = merge_dicts( DDPG_CONFIG, # see also the options in ddpg.py, which are also supported { - "optimizer_class": "AsyncReplayOptimizer", "optimizer": merge_dicts( DDPG_CONFIG["optimizer"], { "max_weight_sync_delay": 400, @@ -32,23 +31,7 @@ }, ) - -class ApexDDPGTrainer(DDPGTrainer): - """DDPG variant that uses the Ape-X distributed policy optimizer. 
- - By default, this is configured for a large single node (32 cores). For - running in a large cluster, increase the `num_workers` config var. - """ - - _name = "APEX_DDPG" - _default_config = APEX_DDPG_DEFAULT_CONFIG - - @override(DDPGTrainer) - def update_target_if_needed(self): - # Ape-X updates based on num steps trained, not sampled - if self.optimizer.num_steps_trained - self.last_target_update_ts > \ - self.config["target_network_update_freq"]: - self.workers.local_worker().foreach_trainable_policy( - lambda p, _: p.update_target()) - self.last_target_update_ts = self.optimizer.num_steps_trained - self.num_target_updates += 1 +ApexDDPGTrainer = DDPGTrainer.with_updates( + name="APEX_DDPG", + default_config=APEX_DDPG_DEFAULT_CONFIG, + **APEX_TRAINER_PROPERTIES) diff --git a/python/ray/rllib/agents/ddpg/ddpg.py b/python/ray/rllib/agents/ddpg/ddpg.py index a9676335eb3f..a6b42f1ca927 100644 --- a/python/ray/rllib/agents/ddpg/ddpg.py +++ b/python/ray/rllib/agents/ddpg/ddpg.py @@ -3,9 +3,9 @@ from __future__ import print_function from ray.rllib.agents.trainer import with_common_config -from ray.rllib.agents.dqn.dqn import DQNTrainer +from ray.rllib.agents.dqn.dqn import GenericOffPolicyTrainer, \ + update_worker_explorations from ray.rllib.agents.ddpg.ddpg_policy import DDPGTFPolicy -from ray.rllib.utils.annotations import override from ray.rllib.utils.schedules import ConstantSchedule, LinearSchedule # yapf: disable @@ -97,6 +97,11 @@ # optimization on initial policy parameters. Note that this will be # disabled when the action noise scale is set to 0 (e.g during evaluation). "pure_exploration_steps": 1000, + # Extra configuration that disables exploration. + "evaluation_config": { + "exploration_fraction": 0, + "exploration_final_eps": 0, + }, # === Replay buffer === # Size of the replay buffer. Note that if async_updates is set, then @@ -108,6 +113,11 @@ "prioritized_replay_alpha": 0.6, # Beta parameter for sampling from prioritized replay buffer. 
"prioritized_replay_beta": 0.4, + # Fraction of entire training period over which the beta parameter is + # annealed + "beta_annealing_fraction": 0.2, + # Final value of beta + "final_prioritized_replay_beta": 0.4, # Epsilon to add to the TD errors when updating priorities. "prioritized_replay_eps": 1e-6, # Whether to LZ4 compress observations @@ -146,8 +156,6 @@ # to increase if your environment is particularly slow to sample, or if # you're using the Async or Ape-X optimizers. "num_workers": 0, - # Optimizer class to use. - "optimizer_class": "SyncReplayOptimizer", # Whether to use a distribution of epsilons across workers for exploration. "per_worker_exploration": False, # Whether to compute priorities on workers. @@ -159,47 +167,56 @@ # yapf: enable -class DDPGTrainer(DQNTrainer): - """DDPG implementation in TensorFlow.""" - _name = "DDPG" - _default_config = DEFAULT_CONFIG - _policy = DDPGTFPolicy +def make_exploration_schedule(config, worker_index): + # Modification of DQN's schedule to take into account + # `exploration_ou_noise_scale` + if config["per_worker_exploration"]: + assert config["num_workers"] > 1, "This requires multiple workers" + if worker_index >= 0: + # FIXME: what do magic constants mean? 
(0.4, 7) + max_index = float(config["num_workers"] - 1) + exponent = 1 + worker_index / max_index * 7 + return ConstantSchedule(0.4**exponent) + else: + # local ev should have zero exploration so that eval rollouts + # run properly + return ConstantSchedule(0.0) + elif config["exploration_should_anneal"]: + return LinearSchedule( + schedule_timesteps=int(config["exploration_fraction"] * + config["schedule_max_timesteps"]), + initial_p=1.0, + final_p=config["exploration_final_scale"]) + else: + # *always* add exploration noise + return ConstantSchedule(1.0) + + +def setup_ddpg_exploration(trainer): + trainer.exploration0 = make_exploration_schedule(trainer.config, -1) + trainer.explorations = [ + make_exploration_schedule(trainer.config, i) + for i in range(trainer.config["num_workers"]) + ] - @override(DQNTrainer) - def _train(self): - pure_expl_steps = self.config["pure_exploration_steps"] - if pure_expl_steps: - # tell workers whether they should do pure exploration - only_explore = self.global_timestep < pure_expl_steps - self.workers.local_worker().foreach_trainable_policy( + +def add_pure_exploration_phase(trainer): + global_timestep = trainer.optimizer.num_steps_sampled + pure_expl_steps = trainer.config["pure_exploration_steps"] + if pure_expl_steps: + # tell workers whether they should do pure exploration + only_explore = global_timestep < pure_expl_steps + trainer.workers.local_worker().foreach_trainable_policy( + lambda p, _: p.set_pure_exploration_phase(only_explore)) + for e in trainer.workers.remote_workers(): + e.foreach_trainable_policy.remote( lambda p, _: p.set_pure_exploration_phase(only_explore)) - for e in self.workers.remote_workers(): - e.foreach_trainable_policy.remote( - lambda p, _: p.set_pure_exploration_phase(only_explore)) - return super(DDPGTrainer, self)._train() - - @override(DQNTrainer) - def _make_exploration_schedule(self, worker_index): - # Override DQN's schedule to take into account - # `exploration_ou_noise_scale` - if 
self.config["per_worker_exploration"]: - assert self.config["num_workers"] > 1, \ - "This requires multiple workers" - if worker_index >= 0: - # FIXME: what do magic constants mean? (0.4, 7) - max_index = float(self.config["num_workers"] - 1) - exponent = 1 + worker_index / max_index * 7 - return ConstantSchedule(0.4**exponent) - else: - # local ev should have zero exploration so that eval rollouts - # run properly - return ConstantSchedule(0.0) - elif self.config["exploration_should_anneal"]: - return LinearSchedule( - schedule_timesteps=int(self.config["exploration_fraction"] * - self.config["schedule_max_timesteps"]), - initial_p=1.0, - final_p=self.config["exploration_final_scale"]) - else: - # *always* add exploration noise - return ConstantSchedule(1.0) + update_worker_explorations(trainer) + + +DDPGTrainer = GenericOffPolicyTrainer.with_updates( + name="DDPG", + default_config=DEFAULT_CONFIG, + default_policy=DDPGTFPolicy, + before_init=setup_ddpg_exploration, + before_train_step=add_pure_exploration_phase) diff --git a/python/ray/rllib/agents/ddpg/td3.py b/python/ray/rllib/agents/ddpg/td3.py index 714c39c6b2f8..ad3675294ce5 100644 --- a/python/ray/rllib/agents/ddpg/td3.py +++ b/python/ray/rllib/agents/ddpg/td3.py @@ -1,3 +1,9 @@ +"""A more stable successor to TD3. + +By default, this uses a near-identical configuration to that reported in the +TD3 paper. +""" + from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -36,7 +42,6 @@ "train_batch_size": 100, "use_huber": False, "target_network_update_freq": 0, - "optimizer_class": "SyncReplayOptimizer", "num_workers": 0, "num_gpus_per_worker": 0, "per_worker_exploration": False, @@ -48,10 +53,5 @@ }, ) - -class TD3Trainer(DDPGTrainer): - """A more stable successor to TD3. 
By default, this uses a near-identical - configuration to that reported in the TD3 paper.""" - - _name = "TD3" - _default_config = TD3_DEFAULT_CONFIG +TD3Trainer = DDPGTrainer.with_updates( + name="TD3", default_config=TD3_DEFAULT_CONFIG) diff --git a/python/ray/rllib/agents/dqn/apex.py b/python/ray/rllib/agents/dqn/apex.py index 129839a27119..ab89256a6b95 100644 --- a/python/ray/rllib/agents/dqn/apex.py +++ b/python/ray/rllib/agents/dqn/apex.py @@ -3,15 +3,14 @@ from __future__ import print_function from ray.rllib.agents.dqn.dqn import DQNTrainer, DEFAULT_CONFIG as DQN_CONFIG +from ray.rllib.optimizers import AsyncReplayOptimizer from ray.rllib.utils import merge_dicts -from ray.rllib.utils.annotations import override # yapf: disable # __sphinx_doc_begin__ APEX_DEFAULT_CONFIG = merge_dicts( DQN_CONFIG, # see also the options in dqn.py, which are also supported { - "optimizer_class": "AsyncReplayOptimizer", "optimizer": merge_dicts( DQN_CONFIG["optimizer"], { "max_weight_sync_delay": 400, @@ -36,22 +35,50 @@ # yapf: enable -class ApexTrainer(DQNTrainer): - """DQN variant that uses the Ape-X distributed policy optimizer. +def defer_make_workers(trainer, env_creator, policy, config): + # Hack to workaround https://github.com/ray-project/ray/issues/2541 + # The workers will be created later, after the optimizer is created + return trainer._make_workers(env_creator, policy, config, 0) - By default, this is configured for a large single node (32 cores). For - running in a large cluster, increase the `num_workers` config var. 
- """ - _name = "APEX" - _default_config = APEX_DEFAULT_CONFIG +def make_async_optimizer(workers, config): + assert len(workers.remote_workers()) == 0 + extra_config = config["optimizer"].copy() + for key in [ + "prioritized_replay", "prioritized_replay_alpha", + "prioritized_replay_beta", "prioritized_replay_eps" + ]: + if key in config: + extra_config[key] = config[key] + opt = AsyncReplayOptimizer( + workers, + learning_starts=config["learning_starts"], + buffer_size=config["buffer_size"], + train_batch_size=config["train_batch_size"], + sample_batch_size=config["sample_batch_size"], + **extra_config) + workers.add_workers(config["num_workers"]) + opt._set_workers(workers.remote_workers()) + return opt - @override(DQNTrainer) - def update_target_if_needed(self): - # Ape-X updates based on num steps trained, not sampled - if self.optimizer.num_steps_trained - self.last_target_update_ts > \ - self.config["target_network_update_freq"]: - self.workers.local_worker().foreach_trainable_policy( - lambda p, _: p.update_target()) - self.last_target_update_ts = self.optimizer.num_steps_trained - self.num_target_updates += 1 + +def update_target_based_on_num_steps_trained(trainer, fetches): + # Ape-X updates based on num steps trained, not sampled + if (trainer.optimizer.num_steps_trained - + trainer.state["last_target_update_ts"] > + trainer.config["target_network_update_freq"]): + trainer.workers.local_worker().foreach_trainable_policy( + lambda p, _: p.update_target()) + trainer.state["last_target_update_ts"] = ( + trainer.optimizer.num_steps_trained) + trainer.state["num_target_updates"] += 1 + + +APEX_TRAINER_PROPERTIES = { + "make_workers": defer_make_workers, + "make_policy_optimizer": make_async_optimizer, + "after_optimizer_step": update_target_based_on_num_steps_trained, +} + +ApexTrainer = DQNTrainer.with_updates( + name="APEX", default_config=APEX_DEFAULT_CONFIG, **APEX_TRAINER_PROPERTIES) diff --git a/python/ray/rllib/agents/dqn/dqn.py 
b/python/ray/rllib/agents/dqn/dqn.py index 15379e3fb394..cc418907a0b9 100644 --- a/python/ray/rllib/agents/dqn/dqn.py +++ b/python/ray/rllib/agents/dqn/dqn.py @@ -3,27 +3,17 @@ from __future__ import print_function import logging -import time from ray import tune -from ray.rllib import optimizers -from ray.rllib.agents.trainer import Trainer, with_common_config +from ray.rllib.agents.trainer import with_common_config +from ray.rllib.agents.trainer_template import build_trainer from ray.rllib.agents.dqn.dqn_policy import DQNTFPolicy -from ray.rllib.evaluation.metrics import collect_metrics +from ray.rllib.optimizers import SyncReplayOptimizer from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID -from ray.rllib.utils.annotations import override from ray.rllib.utils.schedules import ConstantSchedule, LinearSchedule logger = logging.getLogger(__name__) -OPTIMIZER_SHARED_CONFIGS = [ - "buffer_size", "prioritized_replay", "prioritized_replay_alpha", - "prioritized_replay_beta", "schedule_max_timesteps", - "beta_annealing_fraction", "final_prioritized_replay_beta", - "prioritized_replay_eps", "sample_batch_size", "train_batch_size", - "learning_starts" -] - # yapf: disable # __sphinx_doc_begin__ DEFAULT_CONFIG = with_common_config({ @@ -53,7 +43,8 @@ # 1.0 to exploration_fraction over this number of timesteps scaled by # exploration_fraction "schedule_max_timesteps": 100000, - # Number of env steps to optimize for before returning + # Minimum env steps to optimize for per train call. This value does + # not affect learning, only the length of iterations. "timesteps_per_iteration": 1000, # Fraction of entire training period over which the exploration rate is # annealed @@ -70,6 +61,11 @@ # If True parameter space noise will be used for exploration # See https://blog.openai.com/better-exploration-with-parameter-noise/ "parameter_noise": False, + # Extra configuration that disables exploration. 
+ "evaluation_config": { + "exploration_fraction": 0, + "exploration_final_eps": 0, + }, # === Replay buffer === # Size of the replay buffer. Note that if async_updates is set, then @@ -115,8 +111,6 @@ # to increase if your environment is particularly slow to sample, or if # you"re using the Async or Ape-X optimizers. "num_workers": 0, - # Optimizer class to use. - "optimizer_class": "SyncReplayOptimizer", # Whether to use a distribution of epsilons across workers for exploration. "per_worker_exploration": False, # Whether to compute priorities on workers. @@ -128,202 +122,175 @@ # yapf: enable -class DQNTrainer(Trainer): - """DQN implementation in TensorFlow.""" - - _name = "DQN" - _default_config = DEFAULT_CONFIG - _policy = DQNTFPolicy - _optimizer_shared_configs = OPTIMIZER_SHARED_CONFIGS - - @override(Trainer) - def _init(self, config, env_creator): - self._validate_config() - - # Update effective batch size to include n-step - adjusted_batch_size = max(config["sample_batch_size"], - config.get("n_step", 1)) - config["sample_batch_size"] = adjusted_batch_size - - self.exploration0 = self._make_exploration_schedule(-1) - self.explorations = [ - self._make_exploration_schedule(i) - for i in range(config["num_workers"]) - ] - - for k in self._optimizer_shared_configs: - if self._name != "DQN" and k in [ - "schedule_max_timesteps", "beta_annealing_fraction", - "final_prioritized_replay_beta" - ]: - # only Rainbow needs annealing prioritized_replay_beta - continue - if k not in config["optimizer"]: - config["optimizer"][k] = config[k] - - if config.get("parameter_noise", False): - if config["callbacks"]["on_episode_start"]: - start_callback = config["callbacks"]["on_episode_start"] - else: - start_callback = None - - def on_episode_start(info): - # as a callback function to sample and pose parameter space - # noise on the parameters of network - policies = info["policy"] - for pi in policies.values(): - pi.add_parameter_noise() - if start_callback: - 
start_callback(info) - - config["callbacks"]["on_episode_start"] = tune.function( - on_episode_start) - if config["callbacks"]["on_episode_end"]: - end_callback = config["callbacks"]["on_episode_end"] - else: - end_callback = None - - def on_episode_end(info): - # as a callback function to monitor the distance - # between noisy policy and original policy - policies = info["policy"] - episode = info["episode"] - episode.custom_metrics["policy_distance"] = policies[ - DEFAULT_POLICY_ID].pi_distance - if end_callback: - end_callback(info) - - config["callbacks"]["on_episode_end"] = tune.function( - on_episode_end) - - if config["optimizer_class"] != "AsyncReplayOptimizer": - self.workers = self._make_workers( - env_creator, - self._policy, - config, - num_workers=self.config["num_workers"]) - workers_needed = 0 +def make_optimizer(workers, config): + return SyncReplayOptimizer( + workers, + learning_starts=config["learning_starts"], + buffer_size=config["buffer_size"], + prioritized_replay=config["prioritized_replay"], + prioritized_replay_alpha=config["prioritized_replay_alpha"], + prioritized_replay_beta=config["prioritized_replay_beta"], + schedule_max_timesteps=config["schedule_max_timesteps"], + beta_annealing_fraction=config["beta_annealing_fraction"], + final_prioritized_replay_beta=config["final_prioritized_replay_beta"], + prioritized_replay_eps=config["prioritized_replay_eps"], + train_batch_size=config["train_batch_size"], + sample_batch_size=config["sample_batch_size"], + **config["optimizer"]) + + +def check_config_and_setup_param_noise(config): + """Update the config based on settings. + + Rewrites sample_batch_size to take into account n_step truncation, and also + adds the necessary callbacks to support parameter space noise exploration. 
+ """ + + # Update effective batch size to include n-step + adjusted_batch_size = max(config["sample_batch_size"], + config.get("n_step", 1)) + config["sample_batch_size"] = adjusted_batch_size + + if config.get("parameter_noise", False): + if config["batch_mode"] != "complete_episodes": + raise ValueError("Exploration with parameter space noise requires " + "batch_mode to be complete_episodes.") + if config.get("noisy", False): + raise ValueError( + "Exploration with parameter space noise and noisy network " + "cannot be used at the same time.") + if config["callbacks"]["on_episode_start"]: + start_callback = config["callbacks"]["on_episode_start"] + else: + start_callback = None + + def on_episode_start(info): + # as a callback function to sample and pose parameter space + # noise on the parameters of network + policies = info["policy"] + for pi in policies.values(): + pi.add_parameter_noise() + if start_callback: + start_callback(info) + + config["callbacks"]["on_episode_start"] = tune.function( + on_episode_start) + if config["callbacks"]["on_episode_end"]: + end_callback = config["callbacks"]["on_episode_end"] else: - # Hack to workaround https://github.com/ray-project/ray/issues/2541 - self.workers = self._make_workers( - env_creator, self._policy, config, num_workers=0) - workers_needed = self.config["num_workers"] - - self.optimizer = getattr(optimizers, config["optimizer_class"])( - self.workers, **config["optimizer"]) - - # Create the remote workers *after* the replay actors - if workers_needed > 0: - self.workers.add_workers(workers_needed) - self.optimizer._set_workers(self.workers.remote_workers()) - - self.last_target_update_ts = 0 - self.num_target_updates = 0 - - @override(Trainer) - def _train(self): - start_timestep = self.global_timestep - - # Update worker explorations - exp_vals = [self.exploration0.value(self.global_timestep)] - self.workers.local_worker().foreach_trainable_policy( - lambda p, _: p.set_epsilon(exp_vals[0])) - for i, e in 
enumerate(self.workers.remote_workers()): - exp_val = self.explorations[i].value(self.global_timestep) - e.foreach_trainable_policy.remote( - lambda p, _: p.set_epsilon(exp_val)) - exp_vals.append(exp_val) - - # Do optimization steps - start = time.time() - while (self.global_timestep - start_timestep < - self.config["timesteps_per_iteration"] - ) or time.time() - start < self.config["min_iter_time_s"]: - self.optimizer.step() - self.update_target_if_needed() - - if self.config["per_worker_exploration"]: - # Only collect metrics from the third of workers with lowest eps - result = self.collect_metrics( - selected_workers=self.workers.remote_workers()[ - -len(self.workers.remote_workers()) // 3:]) + end_callback = None + + def on_episode_end(info): + # as a callback function to monitor the distance + # between noisy policy and original policy + policies = info["policy"] + episode = info["episode"] + episode.custom_metrics["policy_distance"] = policies[ + DEFAULT_POLICY_ID].pi_distance + if end_callback: + end_callback(info) + + config["callbacks"]["on_episode_end"] = tune.function(on_episode_end) + + +def get_initial_state(config): + return { + "last_target_update_ts": 0, + "num_target_updates": 0, + } + + +def make_exploration_schedule(config, worker_index): + # Use either a different `eps` per worker, or a linear schedule. 
+ if config["per_worker_exploration"]: + assert config["num_workers"] > 1, \ + "This requires multiple workers" + if worker_index >= 0: + exponent = ( + 1 + worker_index / float(config["num_workers"] - 1) * 7) + return ConstantSchedule(0.4**exponent) else: - result = self.collect_metrics() - - result.update( - timesteps_this_iter=self.global_timestep - start_timestep, - info=dict({ - "min_exploration": min(exp_vals), - "max_exploration": max(exp_vals), - "num_target_updates": self.num_target_updates, - }, **self.optimizer.stats())) - - return result - - def update_target_if_needed(self): - if self.global_timestep - self.last_target_update_ts > \ - self.config["target_network_update_freq"]: - self.workers.local_worker().foreach_trainable_policy( - lambda p, _: p.update_target()) - self.last_target_update_ts = self.global_timestep - self.num_target_updates += 1 - - @property - def global_timestep(self): - return self.optimizer.num_steps_sampled - - def _evaluate(self): - logger.info("Evaluating current policy for {} episodes".format( - self.config["evaluation_num_episodes"])) - self.evaluation_workers.local_worker().restore( - self.workers.local_worker().save()) - self.evaluation_workers.local_worker().foreach_policy( - lambda p, _: p.set_epsilon(0)) - for _ in range(self.config["evaluation_num_episodes"]): - self.evaluation_workers.local_worker().sample() - metrics = collect_metrics(self.evaluation_workers.local_worker()) - return {"evaluation": metrics} - - def _make_exploration_schedule(self, worker_index): - # Use either a different `eps` per worker, or a linear schedule. 
- if self.config["per_worker_exploration"]: - assert self.config["num_workers"] > 1, \ - "This requires multiple workers" - if worker_index >= 0: - exponent = ( - 1 + - worker_index / float(self.config["num_workers"] - 1) * 7) - return ConstantSchedule(0.4**exponent) - else: - # local ev should have zero exploration so that eval rollouts - # run properly - return ConstantSchedule(0.0) - return LinearSchedule( - schedule_timesteps=int(self.config["exploration_fraction"] * - self.config["schedule_max_timesteps"]), - initial_p=1.0, - final_p=self.config["exploration_final_eps"]) - - def __getstate__(self): - state = Trainer.__getstate__(self) - state.update({ - "num_target_updates": self.num_target_updates, - "last_target_update_ts": self.last_target_update_ts, - }) - return state - - def __setstate__(self, state): - Trainer.__setstate__(self, state) - self.num_target_updates = state["num_target_updates"] - self.last_target_update_ts = state["last_target_update_ts"] - - def _validate_config(self): - if self.config.get("parameter_noise", False): - if self.config["batch_mode"] != "complete_episodes": - raise ValueError( - "Exploration with parameter space noise requires " - "batch_mode to be complete_episodes.") - if self.config.get("noisy", False): - raise ValueError( - "Exploration with parameter space noise and noisy network " - "cannot be used at the same time.") + # local ev should have zero exploration so that eval rollouts + # run properly + return ConstantSchedule(0.0) + return LinearSchedule( + schedule_timesteps=int( + config["exploration_fraction"] * config["schedule_max_timesteps"]), + initial_p=1.0, + final_p=config["exploration_final_eps"]) + + +def setup_exploration(trainer): + trainer.exploration0 = make_exploration_schedule(trainer.config, -1) + trainer.explorations = [ + make_exploration_schedule(trainer.config, i) + for i in range(trainer.config["num_workers"]) + ] + + +def update_worker_explorations(trainer): + global_timestep = 
trainer.optimizer.num_steps_sampled + exp_vals = [trainer.exploration0.value(global_timestep)] + trainer.workers.local_worker().foreach_trainable_policy( + lambda p, _: p.set_epsilon(exp_vals[0])) + for i, e in enumerate(trainer.workers.remote_workers()): + exp_val = trainer.explorations[i].value(global_timestep) + e.foreach_trainable_policy.remote(lambda p, _: p.set_epsilon(exp_val)) + exp_vals.append(exp_val) + trainer.train_start_timestep = global_timestep + trainer.cur_exp_vals = exp_vals + + +def add_trainer_metrics(trainer, result): + global_timestep = trainer.optimizer.num_steps_sampled + result.update( + timesteps_this_iter=global_timestep - trainer.train_start_timestep, + info=dict({ + "min_exploration": min(trainer.cur_exp_vals), + "max_exploration": max(trainer.cur_exp_vals), + "num_target_updates": trainer.state["num_target_updates"], + }, **trainer.optimizer.stats())) + + +def update_target_if_needed(trainer, fetches): + global_timestep = trainer.optimizer.num_steps_sampled + if global_timestep - trainer.state["last_target_update_ts"] > \ + trainer.config["target_network_update_freq"]: + trainer.workers.local_worker().foreach_trainable_policy( + lambda p, _: p.update_target()) + trainer.state["last_target_update_ts"] = global_timestep + trainer.state["num_target_updates"] += 1 + + +def collect_metrics(trainer): + if trainer.config["per_worker_exploration"]: + # Only collect metrics from the third of workers with lowest eps + result = trainer.collect_metrics( + selected_workers=trainer.workers.remote_workers()[ + -len(trainer.workers.remote_workers()) // 3:]) + else: + result = trainer.collect_metrics() + return result + + +def disable_exploration(trainer): + trainer.evaluation_workers.local_worker().foreach_policy( + lambda p, _: p.set_epsilon(0)) + + +GenericOffPolicyTrainer = build_trainer( + name="GenericOffPolicyAlgorithm", + default_policy=None, + default_config=DEFAULT_CONFIG, + validate_config=check_config_and_setup_param_noise, + 
get_initial_state=get_initial_state, + make_policy_optimizer=make_optimizer, + before_init=setup_exploration, + before_train_step=update_worker_explorations, + after_optimizer_step=update_target_if_needed, + after_train_result=add_trainer_metrics, + collect_metrics_fn=collect_metrics, + before_evaluate_fn=disable_exploration) + +DQNTrainer = GenericOffPolicyTrainer.with_updates( + name="DQN", default_policy=DQNTFPolicy, default_config=DEFAULT_CONFIG) diff --git a/python/ray/rllib/agents/impala/impala.py b/python/ray/rllib/agents/impala/impala.py index e025a4817f8f..b9699888bfaf 100644 --- a/python/ray/rllib/agents/impala/impala.py +++ b/python/ray/rllib/agents/impala/impala.py @@ -2,33 +2,16 @@ from __future__ import division from __future__ import print_function -import time - from ray.rllib.agents.a3c.a3c_tf_policy import A3CTFPolicy from ray.rllib.agents.impala.vtrace_policy import VTraceTFPolicy from ray.rllib.agents.trainer import Trainer, with_common_config +from ray.rllib.agents.trainer_template import build_trainer from ray.rllib.optimizers import AsyncSamplesOptimizer from ray.rllib.optimizers.aso_tree_aggregator import TreeAggregator from ray.rllib.utils.annotations import override from ray.tune.trainable import Trainable from ray.tune.trial import Resources -OPTIMIZER_SHARED_CONFIGS = [ - "lr", - "num_envs_per_worker", - "num_gpus", - "sample_batch_size", - "train_batch_size", - "replay_buffer_num_slots", - "replay_proportion", - "num_data_loader_buffers", - "max_sample_requests_in_flight_per_worker", - "broadcast_interval", - "num_sgd_iter", - "minibatch_buffer_size", - "num_aggregation_workers", -] - # yapf: disable # __sphinx_doc_begin__ DEFAULT_CONFIG = with_common_config({ @@ -100,37 +83,57 @@ # yapf: enable -class ImpalaTrainer(Trainer): - """IMPALA implementation using DeepMind's V-trace.""" - - _name = "IMPALA" - _default_config = DEFAULT_CONFIG - _policy = VTraceTFPolicy - - @override(Trainer) - def _init(self, config, env_creator): - for k in 
OPTIMIZER_SHARED_CONFIGS: - if k not in config["optimizer"]: - config["optimizer"][k] = config[k] - policy_cls = self._get_policy() - self.workers = self._make_workers( - self.env_creator, policy_cls, self.config, num_workers=0) - - if self.config["num_aggregation_workers"] > 0: - # Create co-located aggregator actors first for placement pref - aggregators = TreeAggregator.precreate_aggregators( - self.config["num_aggregation_workers"]) - - self.workers.add_workers(config["num_workers"]) - self.optimizer = AsyncSamplesOptimizer(self.workers, - **config["optimizer"]) - if config["entropy_coeff"] < 0: - raise DeprecationWarning("entropy_coeff must be >= 0") - - if self.config["num_aggregation_workers"] > 0: - # Assign the pre-created aggregators to the optimizer - self.optimizer.aggregator.init(aggregators) - +def choose_policy(config): + if config["vtrace"]: + return VTraceTFPolicy + else: + return A3CTFPolicy + + +def validate_config(config): + if config["entropy_coeff"] < 0: + raise DeprecationWarning("entropy_coeff must be >= 0") + + +def defer_make_workers(trainer, env_creator, policy, config): + # Defer worker creation to after the optimizer has been created. 
+ return trainer._make_workers(env_creator, policy, config, 0) + + +def make_aggregators_and_optimizer(workers, config): + if config["num_aggregation_workers"] > 0: + # Create co-located aggregator actors first for placement pref + aggregators = TreeAggregator.precreate_aggregators( + config["num_aggregation_workers"]) + else: + aggregators = None + workers.add_workers(config["num_workers"]) + + optimizer = AsyncSamplesOptimizer( + workers, + lr=config["lr"], + num_envs_per_worker=config["num_envs_per_worker"], + num_gpus=config["num_gpus"], + sample_batch_size=config["sample_batch_size"], + train_batch_size=config["train_batch_size"], + replay_buffer_num_slots=config["replay_buffer_num_slots"], + replay_proportion=config["replay_proportion"], + num_data_loader_buffers=config["num_data_loader_buffers"], + max_sample_requests_in_flight_per_worker=config[ + "max_sample_requests_in_flight_per_worker"], + broadcast_interval=config["broadcast_interval"], + num_sgd_iter=config["num_sgd_iter"], + minibatch_buffer_size=config["minibatch_buffer_size"], + num_aggregation_workers=config["num_aggregation_workers"], + **config["optimizer"]) + + if aggregators: + # Assign the pre-created aggregators to the optimizer + optimizer.aggregator.init(aggregators) + return optimizer + + +class OverrideDefaultResourceRequest(object): @classmethod @override(Trainable) def default_resource_request(cls, config): @@ -143,22 +146,13 @@ def default_resource_request(cls, config): cf["num_aggregation_workers"], extra_gpu=cf["num_gpus_per_worker"] * cf["num_workers"]) - @override(Trainer) - def _train(self): - prev_steps = self.optimizer.num_steps_sampled - start = time.time() - self.optimizer.step() - while (time.time() - start < self.config["min_iter_time_s"] - or self.optimizer.num_steps_sampled == prev_steps): - self.optimizer.step() - result = self.collect_metrics() - result.update(timesteps_this_iter=self.optimizer.num_steps_sampled - - prev_steps) - return result - - def _get_policy(self): 
- if self.config["vtrace"]: - policy_cls = self._policy - else: - policy_cls = A3CTFPolicy - return policy_cls + +ImpalaTrainer = build_trainer( + name="IMPALA", + default_config=DEFAULT_CONFIG, + default_policy=VTraceTFPolicy, + validate_config=validate_config, + get_policy_class=choose_policy, + make_workers=defer_make_workers, + make_policy_optimizer=make_aggregators_and_optimizer, + mixins=[OverrideDefaultResourceRequest]) diff --git a/python/ray/rllib/agents/marwil/marwil.py b/python/ray/rllib/agents/marwil/marwil.py index b8e01806ca29..29be38a84c32 100644 --- a/python/ray/rllib/agents/marwil/marwil.py +++ b/python/ray/rllib/agents/marwil/marwil.py @@ -2,10 +2,10 @@ from __future__ import division from __future__ import print_function -from ray.rllib.agents.trainer import Trainer, with_common_config +from ray.rllib.agents.trainer import with_common_config +from ray.rllib.agents.trainer_template import build_trainer from ray.rllib.agents.marwil.marwil_policy import MARWILPolicy from ray.rllib.optimizers import SyncBatchReplayOptimizer -from ray.rllib.utils.annotations import override # yapf: disable # __sphinx_doc_begin__ @@ -39,30 +39,17 @@ # yapf: enable -class MARWILTrainer(Trainer): - """MARWIL implementation in TensorFlow.""" +def make_optimizer(workers, config): + return SyncBatchReplayOptimizer( + workers, + learning_starts=config["learning_starts"], + buffer_size=config["replay_buffer_size"], + train_batch_size=config["train_batch_size"], + ) - _name = "MARWIL" - _default_config = DEFAULT_CONFIG - _policy = MARWILPolicy - @override(Trainer) - def _init(self, config, env_creator): - self.workers = self._make_workers(env_creator, self._policy, config, - config["num_workers"]) - self.optimizer = SyncBatchReplayOptimizer( - self.workers, - learning_starts=config["learning_starts"], - buffer_size=config["replay_buffer_size"], - train_batch_size=config["train_batch_size"], - ) - - @override(Trainer) - def _train(self): - prev_steps = 
self.optimizer.num_steps_sampled - fetches = self.optimizer.step() - res = self.collect_metrics() - res.update( - timesteps_this_iter=self.optimizer.num_steps_sampled - prev_steps, - info=dict(fetches, **res.get("info", {}))) - return res +MARWILTrainer = build_trainer( + name="MARWIL", + default_config=DEFAULT_CONFIG, + default_policy=MARWILPolicy, + make_policy_optimizer=make_optimizer) diff --git a/python/ray/rllib/agents/ppo/appo.py b/python/ray/rllib/agents/ppo/appo.py index 0438b2714221..4b0d9945dec3 100644 --- a/python/ray/rllib/agents/ppo/appo.py +++ b/python/ray/rllib/agents/ppo/appo.py @@ -5,7 +5,6 @@ from ray.rllib.agents.ppo.appo_policy import AsyncPPOTFPolicy from ray.rllib.agents.trainer import with_base_config from ray.rllib.agents import impala -from ray.rllib.utils.annotations import override # yapf: disable # __sphinx_doc_begin__ @@ -51,14 +50,8 @@ # __sphinx_doc_end__ # yapf: enable - -class APPOTrainer(impala.ImpalaTrainer): - """PPO surrogate loss with IMPALA-architecture.""" - - _name = "APPO" - _default_config = DEFAULT_CONFIG - _policy = AsyncPPOTFPolicy - - @override(impala.ImpalaTrainer) - def _get_policy(self): - return AsyncPPOTFPolicy +APPOTrainer = impala.ImpalaTrainer.with_updates( + name="APPO", + default_config=DEFAULT_CONFIG, + default_policy=AsyncPPOTFPolicy, + get_policy_class=lambda _: AsyncPPOTFPolicy) diff --git a/python/ray/rllib/agents/ppo/ppo_policy.py b/python/ray/rllib/agents/ppo/ppo_policy.py index 4b391cab2cdc..ad79d90faa9a 100644 --- a/python/ray/rllib/agents/ppo/ppo_policy.py +++ b/python/ray/rllib/agents/ppo/ppo_policy.py @@ -106,8 +106,10 @@ def reduce_mean_valid(t): def ppo_surrogate_loss(policy, batch_tensors): if policy.model.state_in: - max_seq_len = tf.reduce_max(policy.model.seq_lens) - mask = tf.sequence_mask(policy.model.seq_lens, max_seq_len) + max_seq_len = tf.reduce_max( + policy.convert_to_eager(policy.model.seq_lens)) + mask = tf.sequence_mask( + policy.convert_to_eager(policy.model.seq_lens), 
max_seq_len) mask = tf.reshape(mask, [-1]) else: mask = tf.ones_like( @@ -121,8 +123,8 @@ def ppo_surrogate_loss(policy, batch_tensors): batch_tensors[BEHAVIOUR_LOGITS], batch_tensors[SampleBatch.VF_PREDS], policy.action_dist, - policy.value_function, - policy.kl_coeff, + policy.convert_to_eager(policy.value_function), + policy.convert_to_eager(policy.kl_coeff), mask, entropy_coeff=policy.config["entropy_coeff"], clip_param=policy.config["clip_param"], diff --git a/python/ray/rllib/agents/qmix/apex.py b/python/ray/rllib/agents/qmix/apex.py index 65c91d655af2..aac5d83f726a 100644 --- a/python/ray/rllib/agents/qmix/apex.py +++ b/python/ray/rllib/agents/qmix/apex.py @@ -4,15 +4,14 @@ from __future__ import division from __future__ import print_function +from ray.rllib.agents.dqn.apex import APEX_TRAINER_PROPERTIES from ray.rllib.agents.qmix.qmix import QMixTrainer, \ DEFAULT_CONFIG as QMIX_CONFIG -from ray.rllib.utils.annotations import override from ray.rllib.utils import merge_dicts APEX_QMIX_DEFAULT_CONFIG = merge_dicts( QMIX_CONFIG, # see also the options in qmix.py, which are also supported { - "optimizer_class": "AsyncReplayOptimizer", "optimizer": merge_dicts( QMIX_CONFIG["optimizer"], { @@ -34,23 +33,7 @@ }, ) - -class ApexQMixTrainer(QMixTrainer): - """QMIX variant that uses the Ape-X distributed policy optimizer. - - By default, this is configured for a large single node (32 cores). For - running in a large cluster, increase the `num_workers` config var. 
- """ - - _name = "APEX_QMIX" - _default_config = APEX_QMIX_DEFAULT_CONFIG - - @override(QMixTrainer) - def update_target_if_needed(self): - # Ape-X updates based on num steps trained, not sampled - if self.optimizer.num_steps_trained - self.last_target_update_ts > \ - self.config["target_network_update_freq"]: - self.workers.local_worker().foreach_trainable_policy( - lambda p, _: p.update_target()) - self.last_target_update_ts = self.optimizer.num_steps_trained - self.num_target_updates += 1 +ApexQMixTrainer = QMixTrainer.with_updates( + name="APEX_QMIX", + default_config=APEX_QMIX_DEFAULT_CONFIG, + **APEX_TRAINER_PROPERTIES) diff --git a/python/ray/rllib/agents/qmix/qmix.py b/python/ray/rllib/agents/qmix/qmix.py index 2ad6a3e56f95..6a5bff9d63e8 100644 --- a/python/ray/rllib/agents/qmix/qmix.py +++ b/python/ray/rllib/agents/qmix/qmix.py @@ -3,8 +3,9 @@ from __future__ import print_function from ray.rllib.agents.trainer import with_common_config -from ray.rllib.agents.dqn.dqn import DQNTrainer +from ray.rllib.agents.dqn.dqn import GenericOffPolicyTrainer from ray.rllib.agents.qmix.qmix_policy import QMixTorchPolicy +from ray.rllib.optimizers import SyncBatchReplayOptimizer # yapf: disable # __sphinx_doc_begin__ @@ -71,8 +72,6 @@ # to increase if your environment is particularly slow to sample, or if # you"re using the Async or Ape-X optimizers. "num_workers": 0, - # Optimizer class to use. - "optimizer_class": "SyncBatchReplayOptimizer", # Whether to use a distribution of epsilons across workers for exploration. "per_worker_exploration": False, # Whether to compute priorities on workers. 
@@ -90,12 +89,16 @@ # yapf: enable -class QMixTrainer(DQNTrainer): - """QMix implementation in PyTorch.""" +def make_sync_batch_optimizer(workers, config): + return SyncBatchReplayOptimizer( + workers, + learning_starts=config["learning_starts"], + buffer_size=config["buffer_size"], + train_batch_size=config["train_batch_size"]) - _name = "QMIX" - _default_config = DEFAULT_CONFIG - _policy = QMixTorchPolicy - _optimizer_shared_configs = [ - "learning_starts", "buffer_size", "train_batch_size" - ] + +QMixTrainer = GenericOffPolicyTrainer.with_updates( + name="QMIX", + default_config=DEFAULT_CONFIG, + default_policy=QMixTorchPolicy, + make_policy_optimizer=make_sync_batch_optimizer) diff --git a/python/ray/rllib/agents/trainer.py b/python/ray/rllib/agents/trainer.py index ee6ff7f5de3d..6cf3741ddedc 100644 --- a/python/ray/rllib/agents/trainer.py +++ b/python/ray/rllib/agents/trainer.py @@ -9,7 +9,6 @@ import tempfile import time from datetime import datetime -from types import FunctionType import ray import six @@ -17,8 +16,6 @@ from ray.rllib.evaluation.metrics import collect_metrics from ray.rllib.evaluation.worker_set import WorkerSet from ray.rllib.models import MODEL_DEFAULTS -from ray.rllib.offline import NoopOutput, JsonReader, MixedInput, JsonWriter, \ - ShuffledInput from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID from ray.rllib.utils import FilterManager, deep_update, merge_dicts @@ -72,6 +69,9 @@ # Whether to attempt to continue training if a worker crashes. "ignore_worker_failures": False, "log_sys_usage": False, + # Execute TF loss functions in eager mode. This is currently experimental + # and only really works with the basic PG algorithm. + "use_eager": False, # === Policy === # Arguments to pass to model. 
See models/catalog.py for a full list of the @@ -196,6 +196,9 @@ "remote_env_batch_wait_ms": 0, # Minimum time per iteration "min_iter_time_s": 0, + # Minimum env steps to optimize for per train call. This value does + # not affect learning, only the length of iterations. + "timesteps_per_iteration": 0, # === Offline Datasets === # Specify how to generate experiences: @@ -509,6 +512,7 @@ def _evaluate(self): logger.info("Evaluating current policy for {} episodes".format( self.config["evaluation_num_episodes"])) + self._before_evaluate() self.evaluation_workers.local_worker().restore( self.workers.local_worker().save()) for _ in range(self.config["evaluation_num_episodes"]): @@ -517,6 +521,11 @@ def _evaluate(self): metrics = collect_metrics(self.evaluation_workers.local_worker()) return {"evaluation": metrics} + @DeveloperAPI + def _before_evaluate(self): + """Pre-evaluation callback.""" + pass + @PublicAPI def compute_action(self, observation, diff --git a/python/ray/rllib/agents/trainer_template.py b/python/ray/rllib/agents/trainer_template.py index 6af9e1c781e0..ee0b4181c337 100644 --- a/python/ray/rllib/agents/trainer_template.py +++ b/python/ray/rllib/agents/trainer_template.py @@ -6,6 +6,7 @@ from ray.rllib.agents.trainer import Trainer, COMMON_CONFIG from ray.rllib.optimizers import SyncSamplesOptimizer +from ray.rllib.utils import add_mixins from ray.rllib.utils.annotations import override, DeveloperAPI @@ -13,25 +14,47 @@ def build_trainer(name, default_policy, default_config=None, - make_policy_optimizer=None, validate_config=None, + get_initial_state=None, get_policy_class=None, + before_init=None, + make_workers=None, + make_policy_optimizer=None, + after_init=None, before_train_step=None, after_optimizer_step=None, - after_train_result=None): + after_train_result=None, + collect_metrics_fn=None, + before_evaluate_fn=None, + mixins=None): """Helper function for defining a custom trainer. 
+ Functions will be run in this order to initialize the trainer: + 1. Config setup: validate_config, get_initial_state, get_policy + 2. Worker setup: before_init, make_workers, make_policy_optimizer + 3. Post setup: after_init + Arguments: name (str): name of the trainer (e.g., "PPO") default_policy (cls): the default Policy class to use default_config (dict): the default config dict of the algorithm, otherwises uses the Trainer default config - make_policy_optimizer (func): optional function that returns a - PolicyOptimizer instance given (WorkerSet, config) validate_config (func): optional callback that checks a given config for correctness. It may mutate the config as needed. + get_initial_state (func): optional function that returns the initial + state dict given the trainer instance as an argument. The state + dict must be serializable so that it can be checkpointed, and will + be available as the `trainer.state` variable. get_policy_class (func): optional callback that takes a config and returns the policy class to override the default with + before_init (func): optional function to run at the start of trainer + init that takes the trainer instance as argument + make_workers (func): override the method that creates rollout workers. + This takes in (trainer, env_creator, policy, config) as args. + make_policy_optimizer (func): optional function that returns a + PolicyOptimizer instance given (WorkerSet, config) + after_init (func): optional function to run at the end of trainer init + that takes the trainer instance as argument before_train_step (func): optional callback to run before each train() call. It takes the trainer instance as an argument. after_optimizer_step (func): optional callback to run after each @@ -40,27 +63,47 @@ def build_trainer(name, after_train_result (func): optional callback to run at the end of each train() call. It takes the trainer instance and result dict as arguments, and may mutate the result dict as needed. 
+ collect_metrics_fn (func): override the method used to collect metrics. + It takes the trainer instance as argument. + before_evaluate_fn (func): callback to run before evaluation. This + takes the trainer instance as argument. + mixins (list): list of any class mixins for the returned trainer class. + These mixins will be applied in order and will have higher + precedence than the Trainer class Returns: a Trainer instance that uses the specified args. """ original_kwargs = locals().copy() + base = add_mixins(Trainer, mixins) - class trainer_cls(Trainer): + class trainer_cls(base): _name = name _default_config = default_config or COMMON_CONFIG _policy = default_policy + def __init__(self, config=None, env=None, logger_creator=None): + Trainer.__init__(self, config, env, logger_creator) + def _init(self, config, env_creator): if validate_config: validate_config(config) + if get_initial_state: + self.state = get_initial_state(self) + else: + self.state = {} if get_policy_class is None: policy = default_policy else: policy = get_policy_class(config) - self.workers = self._make_workers(env_creator, policy, config, - self.config["num_workers"]) + if before_init: + before_init(self) + if make_workers: + self.workers = make_workers(self, env_creator, policy, config) + else: + self.workers = self._make_workers(env_creator, policy, config, + self.config["num_workers"]) if make_policy_optimizer: self.optimizer = make_policy_optimizer(self.workers, config) else: @@ -69,6 +112,8 @@ def _init(self, config, env_creator): **{"train_batch_size": config["train_batch_size"]}) self.optimizer = SyncSamplesOptimizer(self.workers, **optimizer_config) + if after_init: + after_init(self) @override(Trainer) def _train(self): @@ -81,20 +126,46 @@ def _train(self): fetches = self.optimizer.step() if after_optimizer_step: after_optimizer_step(self, fetches) - if time.time() - start > self.config["min_iter_time_s"]: + if (time.time() - start >= self.config["min_iter_time_s"] + and 
self.optimizer.num_steps_sampled - prev_steps >= + self.config["timesteps_per_iteration"]): break - res = self.collect_metrics() + if collect_metrics_fn: + res = collect_metrics_fn(self) + else: + res = self.collect_metrics() res.update( timesteps_this_iter=self.optimizer.num_steps_sampled - prev_steps, info=res.get("info", {})) + if after_train_result: after_train_result(self, res) return res + @override(Trainer) + def _before_evaluate(self): + if before_evaluate_fn: + before_evaluate_fn(self) + + def __getstate__(self): + state = Trainer.__getstate__(self) + state.update(self.state) + return state + + def __setstate__(self, state): + Trainer.__setstate__(self, state) + self.state = state + @staticmethod def with_updates(**overrides): + """Build a copy of this trainer with the specified overrides. + + Arguments: + overrides (dict): use this to override any of the arguments + originally passed to build_trainer() for this policy. + """ return build_trainer(**dict(original_kwargs, **overrides)) trainer_cls.with_updates = with_updates diff --git a/python/ray/rllib/examples/eager_execution.py b/python/ray/rllib/examples/eager_execution.py new file mode 100644 index 000000000000..a3c418a33139 --- /dev/null +++ b/python/ray/rllib/examples/eager_execution.py @@ -0,0 +1,101 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import random + +import ray +from ray import tune +from ray.rllib.agents.trainer_template import build_trainer +from ray.rllib.models import FullyConnectedNetwork, Model, ModelCatalog +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.policy.tf_policy_template import build_tf_policy +from ray.rllib.utils import try_import_tf + +tf = try_import_tf() + +parser = argparse.ArgumentParser() +parser.add_argument("--iters", type=int, default=200) + + +class EagerModel(Model): + """Example of using embedded eager execution in a custom model. 
+ + This shows how to use tf.py_function() to execute a snippet of TF code + in eager mode. Here the `self.forward_eager` method just prints out + the intermediate tensor for debug purposes, but you can in general + perform any TF eager operation in tf.py_function(). + """ + + def _build_layers_v2(self, input_dict, num_outputs, options): + self.fcnet = FullyConnectedNetwork(input_dict, self.obs_space, + self.action_space, num_outputs, + options) + feature_out = tf.py_function(self.forward_eager, + [self.fcnet.last_layer], tf.float32) + + with tf.control_dependencies([feature_out]): + return tf.identity(self.fcnet.outputs), feature_out + + def forward_eager(self, feature_layer): + assert tf.executing_eagerly() + if random.random() > 0.99: + print("Eagerly printing the feature layer mean value", + tf.reduce_mean(feature_layer)) + return feature_layer + + +def policy_gradient_loss(policy, batch_tensors): + """Example of using embedded eager execution in a custom loss. + + Here `compute_penalty` prints the actions and rewards for debugging, and + also computes a (dummy) penalty term to add to the loss. + + Alternatively, you can set config["use_eager"] = True, which will try to + automatically eagerify the entire loss function. However, this only works + if your loss doesn't reference any non-eager tensors. It also won't work + with the multi-GPU optimizer used by PPO. 
+ """ + + def compute_penalty(actions, rewards): + assert tf.executing_eagerly() + penalty = tf.reduce_mean(tf.cast(actions, tf.float32)) + if random.random() > 0.9: + print("The eagerly computed penalty is", penalty, actions, rewards) + return penalty + + actions = batch_tensors[SampleBatch.ACTIONS] + rewards = batch_tensors[SampleBatch.REWARDS] + penalty = tf.py_function( + compute_penalty, [actions, rewards], Tout=tf.float32) + + return penalty - tf.reduce_mean(policy.action_dist.logp(actions) * rewards) + + +# +MyTFPolicy = build_tf_policy( + name="MyTFPolicy", + loss_fn=policy_gradient_loss, +) + +# +MyTrainer = build_trainer( + name="MyCustomTrainer", + default_policy=MyTFPolicy, +) + +if __name__ == "__main__": + ray.init() + args = parser.parse_args() + ModelCatalog.register_custom_model("eager_model", EagerModel) + tune.run( + MyTrainer, + stop={"training_iteration": args.iters}, + config={ + "env": "CartPole-v0", + "num_workers": 0, + "model": { + "custom_model": "eager_model" + }, + }) diff --git a/python/ray/rllib/examples/saving_experiences.py b/python/ray/rllib/examples/saving_experiences.py index 7a29b0fe7b0d..d2de88302d23 100644 --- a/python/ray/rllib/examples/saving_experiences.py +++ b/python/ray/rllib/examples/saving_experiences.py @@ -7,6 +7,7 @@ import gym import numpy as np +from ray.rllib.models.preprocessors import get_preprocessor from ray.rllib.evaluation.sample_batch_builder import SampleBatchBuilder from ray.rllib.offline.json_writer import JsonWriter @@ -18,6 +19,12 @@ # simulator is available, but let's do it anyways for example purposes: env = gym.make("CartPole-v0") + # RLlib uses preprocessors to implement transforms such as one-hot encoding + # and flattening of tuple and dict observations. For CartPole a no-op + # preprocessor is used, but this may be relevant for more complex envs. 
+ prep = get_preprocessor(env.observation_space)(env.observation_space) + print("The preprocessor is", prep) + for eps_id in range(100): obs = env.reset() prev_action = np.zeros_like(env.action_space.sample()) @@ -31,7 +38,7 @@ t=t, eps_id=eps_id, agent_index=0, - obs=obs, + obs=prep.transform(obs), actions=action, action_prob=1.0, # put the true action probability here rewards=rew, @@ -39,7 +46,7 @@ prev_rewards=prev_reward, dones=done, infos=info, - new_obs=new_obs) + new_obs=prep.transform(new_obs)) obs = new_obs prev_action = action prev_reward = rew diff --git a/python/ray/rllib/policy/dynamic_tf_policy.py b/python/ray/rllib/policy/dynamic_tf_policy.py index 0240f275de37..23014553bf0d 100644 --- a/python/ray/rllib/policy/dynamic_tf_policy.py +++ b/python/ray/rllib/policy/dynamic_tf_policy.py @@ -167,6 +167,8 @@ def __init__(self, batch_divisibility_req=batch_divisibility_req) # Phase 2 init + self._needs_eager_conversion = set() + self._eager_tensors = {} before_loss_init(self, obs_space, action_space, config) if not existing_inputs: self._initialize_loss() @@ -178,10 +180,26 @@ def get_obs_input_dict(self): """ return self.input_dict + def convert_to_eager(self, tensor): + """Convert a graph tensor accessed in the loss to an eager tensor. + + Experimental. 
+ """ + if tf.executing_eagerly(): + return self._eager_tensors[tensor] + else: + self._needs_eager_conversion.add(tensor) + return tensor + @override(TFPolicy) def copy(self, existing_inputs): """Creates a copy of self using existing input placeholders.""" + if self.config["use_eager"]: + raise ValueError( + "eager not implemented for multi-GPU, try setting " + "`simple_optimizer: true`") + # Note that there might be RNN state inputs at the end of the list if self._state_inputs: num_state_inputs = len(self._state_inputs) + 1 @@ -297,6 +315,38 @@ def fake_array(tensor): loss = self._do_loss_init(batch_tensors) for k in sorted(batch_tensors.accessed_keys): loss_inputs.append((k, batch_tensors[k])) + + # XXX experimental support for automatically eagerifying the loss. + # The main limitation right now is that TF doesn't support mixing eager + # and non-eager tensors, so losses that read non-eager tensors through + # `policy` need to use `policy.convert_to_eager(tensor)`. + if self.config["use_eager"]: + if not self.model: + raise ValueError("eager not implemented in this case") + graph_tensors = list(self._needs_eager_conversion) + + def gen_loss(model_outputs, *args): + # fill in the batch tensor dict with eager tensors + eager_inputs = dict( + zip([k for (k, v) in loss_inputs], + args[:len(loss_inputs)])) + # fill in the eager versions of all accessed graph tensors + self._eager_tensors = dict( + zip(graph_tensors, args[len(loss_inputs):])) + # patch the action dist to use eager mode tensors + self.action_dist.inputs = model_outputs + return self._loss_fn(self, eager_inputs) + + # TODO(ekl) also handle the stats funcs + loss = tf.py_function( + gen_loss, + # cast works around TypeError: Cannot convert provided value + # to EagerTensor. 
Provided value: 0.0 Requested dtype: int64 + [self.model.outputs] + [ + tf.cast(v, tf.float32) for (k, v) in loss_inputs + ] + [tf.cast(t, tf.float32) for t in graph_tensors], + tf.float32) + TFPolicy._initialize_loss(self, loss, loss_inputs) if self._grad_stats_fn: self._stats_fetches.update(self._grad_stats_fn(self, self._grads)) diff --git a/python/ray/rllib/policy/tf_policy.py b/python/ray/rllib/policy/tf_policy.py index 9e5ea301d868..abc5cf546184 100644 --- a/python/ray/rllib/policy/tf_policy.py +++ b/python/ray/rllib/policy/tf_policy.py @@ -205,7 +205,7 @@ def _initialize_loss(self, loss, loss_inputs): self._grads_and_vars) if log_once("loss_used"): - logger.info( + logger.debug( "These tensors were used in the loss_fn:\n\n{}\n".format( summarize(self._loss_input_dict))) diff --git a/python/ray/rllib/policy/tf_policy_template.py b/python/ray/rllib/policy/tf_policy_template.py index b7f33fcb0887..37828bfe18b0 100644 --- a/python/ray/rllib/policy/tf_policy_template.py +++ b/python/ray/rllib/policy/tf_policy_template.py @@ -5,6 +5,7 @@ from ray.rllib.policy.dynamic_tf_policy import DynamicTFPolicy from ray.rllib.policy.policy import Policy, LEARNER_STATS_KEY from ray.rllib.policy.tf_policy import TFPolicy +from ray.rllib.utils import add_mixins from ray.rllib.utils.annotations import override, DeveloperAPI @@ -89,13 +90,7 @@ def build_tf_policy(name, """ original_kwargs = locals().copy() - base = DynamicTFPolicy - while mixins: - - class new_base(mixins.pop(), base): - pass - - base = new_base + base = add_mixins(DynamicTFPolicy, mixins) class policy_cls(base): def __init__(self, diff --git a/python/ray/rllib/policy/torch_policy_template.py b/python/ray/rllib/policy/torch_policy_template.py index 1f4185f9c12e..f1b0c0c682d6 100644 --- a/python/ray/rllib/policy/torch_policy_template.py +++ b/python/ray/rllib/policy/torch_policy_template.py @@ -5,6 +5,7 @@ from ray.rllib.policy.policy import Policy from ray.rllib.policy.torch_policy import TorchPolicy from 
ray.rllib.models.catalog import ModelCatalog +from ray.rllib.utils import add_mixins from ray.rllib.utils.annotations import override, DeveloperAPI @@ -56,13 +57,7 @@ def build_torch_policy(name, """ original_kwargs = locals().copy() - base = TorchPolicy - while mixins: - - class new_base(mixins.pop(), base): - pass - - base = new_base + base = add_mixins(TorchPolicy, mixins) class policy_cls(base): def __init__(self, obs_space, action_space, config): diff --git a/python/ray/rllib/tuned_examples/halfcheetah-ddpg.yaml b/python/ray/rllib/tuned_examples/halfcheetah-ddpg.yaml index 6a4bd52e77fe..0513f7bf6ef1 100644 --- a/python/ray/rllib/tuned_examples/halfcheetah-ddpg.yaml +++ b/python/ray/rllib/tuned_examples/halfcheetah-ddpg.yaml @@ -47,7 +47,6 @@ halfcheetah-ddpg: # === Parallelism === num_workers: 0 num_gpus_per_worker: 0 - optimizer_class: "SyncReplayOptimizer" per_worker_exploration: False worker_side_prioritization: False diff --git a/python/ray/rllib/tuned_examples/mountaincarcontinuous-ddpg.yaml b/python/ray/rllib/tuned_examples/mountaincarcontinuous-ddpg.yaml index 3a8f61229224..87ce8eff58cc 100644 --- a/python/ray/rllib/tuned_examples/mountaincarcontinuous-ddpg.yaml +++ b/python/ray/rllib/tuned_examples/mountaincarcontinuous-ddpg.yaml @@ -47,7 +47,6 @@ mountaincarcontinuous-ddpg: # === Parallelism === num_workers: 0 num_gpus_per_worker: 0 - optimizer_class: "SyncReplayOptimizer" per_worker_exploration: False worker_side_prioritization: False diff --git a/python/ray/rllib/tuned_examples/pendulum-ddpg.yaml b/python/ray/rllib/tuned_examples/pendulum-ddpg.yaml index 59891a86b6bc..a2ad295fb4c0 100644 --- a/python/ray/rllib/tuned_examples/pendulum-ddpg.yaml +++ b/python/ray/rllib/tuned_examples/pendulum-ddpg.yaml @@ -47,7 +47,6 @@ pendulum-ddpg: # === Parallelism === num_workers: 0 num_gpus_per_worker: 0 - optimizer_class: "SyncReplayOptimizer" per_worker_exploration: False worker_side_prioritization: False diff --git a/python/ray/rllib/utils/__init__.py 
b/python/ray/rllib/utils/__init__.py index aad5590fd097..bde901e22a9c 100644 --- a/python/ray/rllib/utils/__init__.py +++ b/python/ray/rllib/utils/__init__.py @@ -27,6 +27,21 @@ def __init__(self, *args, **kw): return DeprecationWrapper +def add_mixins(base, mixins): + """Returns a new class with mixins applied in priority order.""" + + mixins = list(mixins or []) + + while mixins: + + class new_base(mixins.pop(), base): + pass + + base = new_base + + return base + + def renamed_agent(cls): """Helper class for renaming Agent => Trainer with a warning.""" diff --git a/python/ray/services.py b/python/ray/services.py index d6b54f475453..66d4069820d0 100644 --- a/python/ray/services.py +++ b/python/ray/services.py @@ -1206,6 +1206,7 @@ def start_raylet(redis_address, "--java_worker_command={}".format(java_worker_command), "--redis_password={}".format(redis_password or ""), "--temp_dir={}".format(temp_dir), + "--session_dir={}".format(session_dir), ] process_info = start_ray_process( command, diff --git a/python/ray/state.py b/python/ray/state.py index 6b2c8a4ef8bc..14ba49987ec4 100644 --- a/python/ray/state.py +++ b/python/ray/state.py @@ -41,7 +41,7 @@ def _parse_client_table(redis_client): return [] node_info = {} - gcs_entry = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry(message, 0) + gcs_entry = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(message, 0) ordered_client_ids = [] @@ -248,8 +248,7 @@ def _object_table(self, object_id): object_id.binary()) if message is None: return {} - gcs_entry = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry( - message, 0) + gcs_entry = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(message, 0) assert gcs_entry.EntriesLength() > 0 @@ -307,8 +306,7 @@ def _task_table(self, task_id): "", task_id.binary()) if message is None: return {} - gcs_entries = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry( - message, 0) + gcs_entries = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(message, 0) assert gcs_entries.EntriesLength() == 1 @@ -431,8 
+429,7 @@ def _profile_table(self, batch_id): if message is None: return [] - gcs_entries = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry( - message, 0) + gcs_entries = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(message, 0) profile_events = [] for i in range(gcs_entries.EntriesLength()): @@ -815,9 +812,8 @@ def available_resources(self): ray.gcs_utils.XRAY_HEARTBEAT_CHANNEL): continue data = raw_message["data"] - gcs_entries = ( - ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry( - data, 0)) + gcs_entries = (ray.gcs_utils.GcsEntry.GetRootAsGcsEntry( + data, 0)) heartbeat_data = gcs_entries.Entries(0) message = (ray.gcs_utils.HeartbeatTableData. GetRootAsHeartbeatTableData(heartbeat_data, 0)) @@ -871,8 +867,7 @@ def _error_messages(self, driver_id): if message is None: return [] - gcs_entries = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry( - message, 0) + gcs_entries = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(message, 0) error_messages = [] for i in range(gcs_entries.EntriesLength()): error_data = ray.gcs_utils.ErrorTableData.GetRootAsErrorTableData( @@ -934,8 +929,7 @@ def actor_checkpoint_info(self, actor_id): ) if message is None: return None - gcs_entry = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry( - message, 0) + gcs_entry = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(message, 0) entry = ( ray.gcs_utils.ActorCheckpointIdData.GetRootAsActorCheckpointIdData( gcs_entry.Entries(0), 0)) diff --git a/python/ray/tests/perf_integration_tests/test_perf_integration.py b/python/ray/tests/perf_integration_tests/test_perf_integration.py index 2ce2a305a0e8..ff34fe4125fa 100644 --- a/python/ray/tests/perf_integration_tests/test_perf_integration.py +++ b/python/ray/tests/perf_integration_tests/test_perf_integration.py @@ -6,6 +6,7 @@ import pytest import ray +from ray.tests.conftest import _ray_start_cluster num_tasks_submitted = [10**n for n in range(0, 6)] num_tasks_ids = ["{}_tasks".format(i) for i in num_tasks_submitted] @@ -41,3 +42,25 @@ def 
test_task_submission(benchmark, num_tasks): warmup() benchmark(benchmark_task_submission, num_tasks) ray.shutdown() + + +def benchmark_task_forward(f, num_tasks): + ray.get([f.remote() for _ in range(num_tasks)]) + + +@pytest.mark.benchmark +@pytest.mark.parametrize( + "num_tasks", [10**3, 10**4], + ids=[str(num) + "_tasks" for num in [10**3, 10**4]]) +def test_task_forward(benchmark, num_tasks): + with _ray_start_cluster(num_cpus=16, object_store_memory=10**6) as cluster: + cluster.add_node(resources={"my_resource": 100}) + ray.init(redis_address=cluster.redis_address) + + @ray.remote(resources={"my_resource": 0.001}) + def f(): + return 1 + + # Warm up + ray.get([f.remote() for _ in range(100)]) + benchmark(benchmark_task_forward, f, num_tasks) diff --git a/python/ray/tests/test_basic.py b/python/ray/tests/test_basic.py index 50aeca025362..7f1f78d1b5c4 100644 --- a/python/ray/tests/test_basic.py +++ b/python/ray/tests/test_basic.py @@ -1754,7 +1754,7 @@ def f(n): def g(n): time.sleep(n) - time_buffer = 0.5 + time_buffer = 2 start_time = time.time() ray.get([f.remote(0.5), g.remote(0.5)]) @@ -1878,13 +1878,23 @@ def test(self): def test_zero_cpus(shutdown_only): ray.init(num_cpus=0) + # We should be able to execute a task that requires 0 CPU resources. @ray.remote(num_cpus=0) def f(): return 1 - # The task should be able to execute. ray.get(f.remote()) + # We should be able to create an actor that requires 0 CPU resources. 
+ @ray.remote(num_cpus=0) + class Actor(object): + def method(self): + pass + + a = Actor.remote() + x = a.method.remote() + ray.get(x) + def test_zero_cpus_actor(ray_start_cluster): cluster = ray_start_cluster diff --git a/python/ray/tune/requirements-dev.txt b/python/ray/tune/requirements-dev.txt new file mode 100644 index 000000000000..9d3d3ddab12f --- /dev/null +++ b/python/ray/tune/requirements-dev.txt @@ -0,0 +1,9 @@ +flake8 +flake8-quotes +gym +opencv-python +pandas +requests +tabulate +tensorflow +yapf==0.23.0 diff --git a/python/ray/worker.py b/python/ray/worker.py index 7786c742d9b1..7505120574a6 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -1656,7 +1656,7 @@ def listen_error_messages_raylet(worker, task_error_queue, threads_stopped): if msg is None: threads_stopped.wait(timeout=0.01) continue - gcs_entry = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry( + gcs_entry = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry( msg["data"], 0) assert gcs_entry.EntriesLength() == 1 error_data = ray.gcs_utils.ErrorTableData.GetRootAsErrorTableData( diff --git a/python/setup.py b/python/setup.py index 11a2a04e7424..eb200ea7d5e4 100644 --- a/python/setup.py +++ b/python/setup.py @@ -149,8 +149,6 @@ def find_version(*filepath): # NOTE: Don't upgrade the version of six! Doing so causes installation # problems. See https://github.com/ray-project/ray/issues/4169. "six >= 1.0.0", - # The typing module is required by modin. 
- "typing", "flatbuffers", "faulthandler;python_version<'3.3'", ] diff --git a/src/ray/common/id.cc b/src/ray/common/id.cc index 57e41d97d10c..3928d4adfcb7 100644 --- a/src/ray/common/id.cc +++ b/src/ray/common/id.cc @@ -105,8 +105,8 @@ ObjectID ObjectID::ForPut(const TaskID &task_id, int64_t put_index) { } ObjectID ObjectID::ForTaskReturn(const TaskID &task_id, int64_t return_index) { - RAY_CHECK(return_index >= 1 && return_index <= kMaxTaskReturns) << "index=" - << return_index; + RAY_CHECK(return_index >= 1 && return_index <= kMaxTaskReturns) + << "index=" << return_index; ObjectID object_id; std::memcpy(object_id.id_, task_id.Binary().c_str(), task_id.Size()); object_id.index_ = return_index; diff --git a/src/ray/core_worker/common.h b/src/ray/core_worker/common.h index 8317bf181207..3fda406613ef 100644 --- a/src/ray/core_worker/common.h +++ b/src/ray/core_worker/common.h @@ -5,6 +5,8 @@ #include "ray/common/buffer.h" #include "ray/common/id.h" +#include "ray/raylet/raylet_client.h" +#include "ray/raylet/task_spec.h" namespace ray { @@ -12,12 +14,12 @@ namespace ray { enum class WorkerType { WORKER, DRIVER }; /// Language of Ray tasks and workers. -enum class Language { PYTHON, JAVA }; +enum class WorkerLanguage { PYTHON, JAVA }; /// Information about a remote function. struct RayFunction { /// Language of the remote function. - const Language language; + const WorkerLanguage language; /// Function descriptor of the remote function. const std::vector function_descriptor; }; @@ -66,6 +68,35 @@ class TaskArg { const std::shared_ptr data_; }; +/// Task specification, which includes the immutable information about the task +/// which are determined at the submission time. +/// TODO(zhijunfu): this can be removed after everything is moved to protobuf. 
+class TaskSpec { + public: + TaskSpec(const raylet::TaskSpecification &task_spec, + const std::vector &dependencies) + : task_spec_(task_spec), dependencies_(dependencies) {} + + TaskSpec(const raylet::TaskSpecification &&task_spec, + const std::vector &&dependencies) + : task_spec_(task_spec), dependencies_(dependencies) {} + + const raylet::TaskSpecification &GetTaskSpecification() const { return task_spec_; } + + const std::vector &GetDependencies() const { return dependencies_; } + + private: + /// Raylet task specification. + raylet::TaskSpecification task_spec_; + + /// Dependencies. + std::vector dependencies_; +}; + +enum class StoreProviderType { PLASMA }; + +enum class TaskTransportType { RAYLET }; + } // namespace ray #endif // RAY_CORE_WORKER_COMMON_H diff --git a/src/ray/core_worker/context.cc b/src/ray/core_worker/context.cc index fedcfc6625d9..717c52e07076 100644 --- a/src/ray/core_worker/context.cc +++ b/src/ray/core_worker/context.cc @@ -1,5 +1,5 @@ -#include "context.h" +#include "ray/core_worker/context.h" namespace ray { diff --git a/src/ray/core_worker/context.h b/src/ray/core_worker/context.h index 6e0cf3f9f2cf..932d02891b6a 100644 --- a/src/ray/core_worker/context.h +++ b/src/ray/core_worker/context.h @@ -1,7 +1,7 @@ #ifndef RAY_CORE_WORKER_CONTEXT_H #define RAY_CORE_WORKER_CONTEXT_H -#include "common.h" +#include "ray/core_worker/common.h" #include "ray/raylet/task_spec.h" namespace ray { diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index 82f2d885ec58..bcc1bdd963db 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -1,39 +1,47 @@ -#include "core_worker.h" -#include "context.h" +#include "ray/core_worker/core_worker.h" +#include "ray/core_worker/context.h" namespace ray { -CoreWorker::CoreWorker(const enum WorkerType worker_type, const enum Language language, +CoreWorker::CoreWorker(const enum WorkerType worker_type, + const enum WorkerLanguage language, const 
std::string &store_socket, const std::string &raylet_socket, DriverID driver_id) : worker_type_(worker_type), language_(language), - worker_context_(worker_type, driver_id), store_socket_(store_socket), raylet_socket_(raylet_socket), + worker_context_(worker_type, driver_id), + raylet_client_(raylet_socket_, worker_context_.GetWorkerID(), + (worker_type_ == ray::WorkerType::WORKER), + worker_context_.GetCurrentDriverID(), ToTaskLanguage(language_)), task_interface_(*this), object_interface_(*this), - task_execution_interface_(*this) {} - -Status CoreWorker::Connect() { - // connect to plasma. - RAY_ARROW_RETURN_NOT_OK(store_client_.Connect(store_socket_)); - - // connect to raylet. - ::Language lang = ::Language::PYTHON; - if (language_ == ray::Language::JAVA) { - lang = ::Language::JAVA; - } - - // TODO: currently RayletClient would crash in its constructor if it cannot + task_execution_interface_(*this) { + // TODO(zhijunfu): currently RayletClient would crash in its constructor if it cannot // connect to Raylet after a number of retries, this needs to be changed // so that the worker (java/python .etc) can retrieve and handle the error // instead of crashing. 
- raylet_client_ = std::unique_ptr( - new RayletClient(raylet_socket_, worker_context_.GetWorkerID(), - (worker_type_ == ray::WorkerType::WORKER), - worker_context_.GetCurrentDriverID(), lang)); - return Status::OK(); + auto status = store_client_.Connect(store_socket_); + if (!status.ok()) { + RAY_LOG(ERROR) << "Connecting plasma store failed when trying to construct" + << " core worker: " << status.message(); + throw std::runtime_error(status.message()); + } +} + +::Language CoreWorker::ToTaskLanguage(WorkerLanguage language) { + switch (language) { + case ray::WorkerLanguage::JAVA: + return ::Language::JAVA; + break; + case ray::WorkerLanguage::PYTHON: + return ::Language::PYTHON; + break; + default: + RAY_LOG(FATAL) << "invalid language specified: " << static_cast(language); + break; + } } } // namespace ray diff --git a/src/ray/core_worker/core_worker.h b/src/ray/core_worker/core_worker.h index 951b55451f09..e03a8700be81 100644 --- a/src/ray/core_worker/core_worker.h +++ b/src/ray/core_worker/core_worker.h @@ -1,13 +1,13 @@ #ifndef RAY_CORE_WORKER_CORE_WORKER_H #define RAY_CORE_WORKER_CORE_WORKER_H -#include "common.h" -#include "context.h" -#include "object_interface.h" #include "ray/common/buffer.h" +#include "ray/core_worker/common.h" +#include "ray/core_worker/context.h" +#include "ray/core_worker/object_interface.h" +#include "ray/core_worker/task_execution.h" +#include "ray/core_worker/task_interface.h" #include "ray/raylet/raylet_client.h" -#include "task_execution.h" -#include "task_interface.h" namespace ray { @@ -20,18 +20,17 @@ class CoreWorker { /// /// \param[in] worker_type Type of this worker. /// \param[in] langauge Language of this worker. - CoreWorker(const WorkerType worker_type, const Language language, + /// + /// NOTE(zhijunfu): the constructor would throw if a failure happens. 
+ CoreWorker(const WorkerType worker_type, const WorkerLanguage language, const std::string &store_socket, const std::string &raylet_socket, DriverID driver_id = DriverID::Nil()); - /// Connect to raylet. - Status Connect(); - /// Type of this worker. enum WorkerType WorkerType() const { return worker_type_; } /// Language of this worker. - enum Language Language() const { return language_; } + enum WorkerLanguage Language() const { return language_; } /// Return the `CoreWorkerTaskInterface` that contains the methods related to task /// submisson. @@ -46,26 +45,35 @@ class CoreWorker { CoreWorkerTaskExecutionInterface &Execution() { return task_execution_interface_; } private: + /// Translate from WorkLanguage to Language type (required by raylet client). + /// + /// \param[in] language Language for a task. + /// \return Translated task language. + ::Language ToTaskLanguage(WorkerLanguage language); + /// Type of this worker. const enum WorkerType worker_type_; /// Language of this worker. - const enum Language language_; - - /// Worker context per thread. - WorkerContext worker_context_; + const enum WorkerLanguage language_; /// Plasma store socket name. - std::string store_socket_; + const std::string store_socket_; /// raylet socket name. - std::string raylet_socket_; + const std::string raylet_socket_; + + /// Worker context. + WorkerContext worker_context_; /// Plasma store client. plasma::PlasmaClient store_client_; + /// Mutex to protect store_client_. + std::mutex store_client_mutex_; + /// Raylet client. - std::unique_ptr raylet_client_; + RayletClient raylet_client_; /// The `CoreWorkerTaskInterface` instance. 
CoreWorkerTaskInterface task_interface_; diff --git a/src/ray/core_worker/core_worker_test.cc b/src/ray/core_worker/core_worker_test.cc index e440aae24d67..6e4ecc161fb4 100644 --- a/src/ray/core_worker/core_worker_test.cc +++ b/src/ray/core_worker/core_worker_test.cc @@ -2,9 +2,9 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" -#include "context.h" -#include "core_worker.h" #include "ray/common/buffer.h" +#include "ray/core_worker/context.h" +#include "ray/core_worker/core_worker.h" #include "ray/raylet/raylet_client.h" #include @@ -18,6 +18,7 @@ namespace ray { std::string store_executable; std::string raylet_executable; +std::string mock_worker_executable; ray::ObjectID RandomObjectID() { return ObjectID::FromRandom(); } @@ -32,6 +33,9 @@ static void flushall_redis(void) { class CoreWorkerTest : public ::testing::Test { public: CoreWorkerTest(int num_nodes) { + // flush redis first. + flushall_redis(); + RAY_CHECK(num_nodes >= 0); if (num_nodes > 0) { raylet_socket_names_.resize(num_nodes); @@ -43,10 +47,12 @@ class CoreWorkerTest : public ::testing::Test { store_socket = StartStore(); } - // start raylet on each node + // start raylet on each node. Assign each node with different resources so that + // a task can be scheduled to the desired node. for (int i = 0; i < num_nodes; i++) { - raylet_socket_names_[i] = StartRaylet(raylet_store_socket_names_[i], "127.0.0.1", - "127.0.0.1", "\"CPU,4.0\""); + raylet_socket_names_[i] = + StartRaylet(raylet_store_socket_names_[i], "127.0.0.1", "127.0.0.1", + "\"CPU,4.0,resource" + std::to_string(i) + ",10\""); } } @@ -66,7 +72,7 @@ class CoreWorkerTest : public ::testing::Test { std::string plasma_command = store_executable + " -m 10000000 -s " + store_socket_name + " 1> /dev/null 2> /dev/null & echo $! 
> " + store_pid; - RAY_LOG(INFO) << plasma_command; + RAY_LOG(DEBUG) << plasma_command; RAY_CHECK(system(plasma_command.c_str()) == 0); usleep(200 * 1000); return store_socket_name; @@ -75,7 +81,7 @@ class CoreWorkerTest : public ::testing::Test { void StopStore(std::string store_socket_name) { std::string store_pid = store_socket_name + ".pid"; std::string kill_9 = "kill -9 `cat " + store_pid + "`"; - RAY_LOG(INFO) << kill_9; + RAY_LOG(DEBUG) << kill_9; ASSERT_TRUE(system(kill_9.c_str()) == 0); ASSERT_TRUE(system(("rm -rf " + store_socket_name).c_str()) == 0); ASSERT_TRUE(system(("rm -rf " + store_socket_name + ".pid").c_str()) == 0); @@ -91,13 +97,14 @@ class CoreWorkerTest : public ::testing::Test { .append(" --node_ip_address=" + node_ip_address) .append(" --redis_address=" + redis_address) .append(" --redis_port=6379") - .append(" --num_initial_workers=0") + .append(" --num_initial_workers=1") .append(" --maximum_startup_concurrency=10") .append(" --static_resource_list=" + resource) - .append(" --python_worker_command=NoneCmd") + .append(" --python_worker_command=\"" + mock_worker_executable + " " + + store_socket_name + " " + raylet_socket_name + "\"") .append(" & echo $! 
> " + raylet_socket_name + ".pid"); - RAY_LOG(INFO) << "Ray Start command: " << ray_start_cmd; + RAY_LOG(DEBUG) << "Ray Start command: " << ray_start_cmd; RAY_CHECK(system(ray_start_cmd.c_str()) == 0); usleep(200 * 1000); return raylet_socket_name; @@ -106,16 +113,131 @@ class CoreWorkerTest : public ::testing::Test { void StopRaylet(std::string raylet_socket_name) { std::string raylet_pid = raylet_socket_name + ".pid"; std::string kill_9 = "kill -9 `cat " + raylet_pid + "`"; - RAY_LOG(INFO) << kill_9; + RAY_LOG(DEBUG) << kill_9; ASSERT_TRUE(system(kill_9.c_str()) == 0); ASSERT_TRUE(system(("rm -rf " + raylet_socket_name).c_str()) == 0); ASSERT_TRUE(system(("rm -rf " + raylet_socket_name + ".pid").c_str()) == 0); } - void SetUp() { flushall_redis(); } + void SetUp() {} void TearDown() {} + void TestNormalTask(const std::unordered_map &resources) { + CoreWorker driver(WorkerType::DRIVER, WorkerLanguage::PYTHON, + raylet_store_socket_names_[0], raylet_socket_names_[0], + DriverID::FromRandom()); + + // Test pass by value. + { + uint8_t array1[] = {1, 2, 3, 4, 5, 6, 7, 8}; + + auto buffer1 = std::make_shared(array1, sizeof(array1)); + + RayFunction func{ray::WorkerLanguage::PYTHON, {}}; + std::vector args; + args.emplace_back(TaskArg::PassByValue(buffer1)); + + TaskOptions options; + + std::vector return_ids; + RAY_CHECK_OK(driver.Tasks().SubmitTask(func, args, options, &return_ids)); + + ASSERT_EQ(return_ids.size(), 1); + + std::vector> results; + RAY_CHECK_OK(driver.Objects().Get(return_ids, -1, &results)); + + ASSERT_EQ(results.size(), 1); + ASSERT_EQ(results[0]->Size(), buffer1->Size()); + ASSERT_EQ(memcmp(results[0]->Data(), buffer1->Data(), buffer1->Size()), 0); + } + + // Test pass by reference. 
+ { + uint8_t array1[] = {10, 11, 12, 13, 14, 15}; + auto buffer1 = std::make_shared(array1, sizeof(array1)); + + ObjectID object_id; + RAY_CHECK_OK(driver.Objects().Put(*buffer1, &object_id)); + + std::vector args; + args.emplace_back(TaskArg::PassByReference(object_id)); + + RayFunction func{ray::WorkerLanguage::PYTHON, {}}; + TaskOptions options; + + std::vector return_ids; + RAY_CHECK_OK(driver.Tasks().SubmitTask(func, args, options, &return_ids)); + + ASSERT_EQ(return_ids.size(), 1); + + std::vector> results; + RAY_CHECK_OK(driver.Objects().Get(return_ids, -1, &results)); + + ASSERT_EQ(results.size(), 1); + ASSERT_EQ(results[0]->Size(), buffer1->Size()); + ASSERT_EQ(memcmp(results[0]->Data(), buffer1->Data(), buffer1->Size()), 0); + } + } + + void TestActorTask(const std::unordered_map &resources) { + CoreWorker driver(WorkerType::DRIVER, WorkerLanguage::PYTHON, + raylet_store_socket_names_[0], raylet_socket_names_[0], + DriverID::FromRandom()); + + std::unique_ptr actor_handle; + + // Test creating actor. + { + uint8_t array[] = {1, 2, 3}; + auto buffer = std::make_shared(array, sizeof(array)); + + RayFunction func{ray::WorkerLanguage::PYTHON, {}}; + std::vector args; + args.emplace_back(TaskArg::PassByValue(buffer)); + + ActorCreationOptions actor_options{0, resources}; + + // Create an actor. + RAY_CHECK_OK(driver.Tasks().CreateActor(func, args, actor_options, &actor_handle)); + } + + // Test submitting a task for that actor. + { + uint8_t array1[] = {1, 2, 3, 4, 5, 6, 7, 8}; + uint8_t array2[] = {10, 11, 12, 13, 14, 15}; + + auto buffer1 = std::make_shared(array1, sizeof(array1)); + auto buffer2 = std::make_shared(array2, sizeof(array2)); + + ObjectID object_id; + RAY_CHECK_OK(driver.Objects().Put(*buffer1, &object_id)); + + // Create arguments with PassByRef and PassByValue. 
+ std::vector args; + args.emplace_back(TaskArg::PassByReference(object_id)); + args.emplace_back(TaskArg::PassByValue(buffer2)); + + TaskOptions options{1, resources}; + std::vector return_ids; + RayFunction func{ray::WorkerLanguage::PYTHON, {}}; + RAY_CHECK_OK(driver.Tasks().SubmitActorTask(*actor_handle, func, args, options, + &return_ids)); + RAY_CHECK(return_ids.size() == 1); + + std::vector> results; + RAY_CHECK_OK(driver.Objects().Get(return_ids, -1, &results)); + + ASSERT_EQ(results.size(), 1); + ASSERT_EQ(results[0]->Size(), buffer1->Size() + buffer2->Size()); + ASSERT_EQ(memcmp(results[0]->Data(), buffer1->Data(), buffer1->Size()), 0); + ASSERT_EQ( + memcmp(results[0]->Data() + buffer1->Size(), buffer2->Data(), buffer2->Size()), + 0); + } + } + protected: std::vector raylet_socket_names_; std::vector raylet_store_socket_names_; @@ -131,6 +253,11 @@ class SingleNodeTest : public CoreWorkerTest { SingleNodeTest() : CoreWorkerTest(1) {} }; +class TwoNodeTest : public CoreWorkerTest { + public: + TwoNodeTest() : CoreWorkerTest(2) {} +}; + TEST_F(ZeroNodeTest, TestTaskArg) { // Test by-reference argument. 
ObjectID id = ObjectID::FromRandom(); @@ -147,13 +274,6 @@ TEST_F(ZeroNodeTest, TestTaskArg) { ASSERT_EQ(*data, *buffer); } -TEST_F(ZeroNodeTest, TestAttributeGetters) { - CoreWorker core_worker(WorkerType::DRIVER, Language::PYTHON, "", "", - DriverID::FromRandom()); - ASSERT_EQ(core_worker.WorkerType(), WorkerType::DRIVER); - ASSERT_EQ(core_worker.Language(), Language::PYTHON); -} - TEST_F(ZeroNodeTest, TestWorkerContext) { auto driver_id = DriverID::FromRandom(); @@ -180,10 +300,9 @@ TEST_F(ZeroNodeTest, TestWorkerContext) { } TEST_F(SingleNodeTest, TestObjectInterface) { - CoreWorker core_worker(WorkerType::DRIVER, Language::PYTHON, + CoreWorker core_worker(WorkerType::DRIVER, WorkerLanguage::PYTHON, raylet_store_socket_names_[0], raylet_socket_names_[0], DriverID::FromRandom()); - RAY_CHECK_OK(core_worker.Connect()); uint8_t array1[] = {1, 2, 3, 4, 5, 6, 7, 8}; uint8_t array2[] = {10, 11, 12, 13, 14, 15}; @@ -193,16 +312,16 @@ TEST_F(SingleNodeTest, TestObjectInterface) { buffers.emplace_back(array2, sizeof(array2)); std::vector ids(buffers.size()); - for (int i = 0; i < ids.size(); i++) { - core_worker.Objects().Put(buffers[i], &ids[i]); + for (size_t i = 0; i < ids.size(); i++) { + RAY_CHECK_OK(core_worker.Objects().Put(buffers[i], &ids[i])); } // Test Get(). 
std::vector> results; - core_worker.Objects().Get(ids, 0, &results); + RAY_CHECK_OK(core_worker.Objects().Get(ids, -1, &results)); ASSERT_EQ(results.size(), 2); - for (int i = 0; i < ids.size(); i++) { + for (size_t i = 0; i < ids.size(); i++) { ASSERT_EQ(results[i]->Size(), buffers[i].Size()); ASSERT_EQ(memcmp(results[i]->Data(), buffers[i].Data(), buffers[i].Size()), 0); } @@ -213,34 +332,133 @@ TEST_F(SingleNodeTest, TestObjectInterface) { all_ids.push_back(non_existent_id); std::vector wait_results; - core_worker.Objects().Wait(all_ids, 2, -1, &wait_results); + RAY_CHECK_OK(core_worker.Objects().Wait(all_ids, 2, -1, &wait_results)); ASSERT_EQ(wait_results.size(), 3); ASSERT_EQ(wait_results, std::vector({true, true, false})); - core_worker.Objects().Wait(all_ids, 3, 100, &wait_results); + RAY_CHECK_OK(core_worker.Objects().Wait(all_ids, 3, 100, &wait_results)); ASSERT_EQ(wait_results.size(), 3); ASSERT_EQ(wait_results, std::vector({true, true, false})); // Test Delete(). // clear the reference held by PlasmaBuffer. results.clear(); - core_worker.Objects().Delete(ids, true, false); + RAY_CHECK_OK(core_worker.Objects().Delete(ids, true, false)); // Note that Delete() calls RayletClient::FreeObjects and would not // wait for objects being deleted, so wait a while for plasma store // to process the command. 
usleep(200 * 1000); - core_worker.Objects().Get(ids, 0, &results); + RAY_CHECK_OK(core_worker.Objects().Get(ids, 0, &results)); ASSERT_EQ(results.size(), 2); ASSERT_TRUE(!results[0]); ASSERT_TRUE(!results[1]); } +TEST_F(TwoNodeTest, TestObjectInterfaceCrossNodes) { + CoreWorker worker1(WorkerType::DRIVER, WorkerLanguage::PYTHON, + raylet_store_socket_names_[0], raylet_socket_names_[0], + DriverID::FromRandom()); + + CoreWorker worker2(WorkerType::DRIVER, WorkerLanguage::PYTHON, + raylet_store_socket_names_[1], raylet_socket_names_[1], + DriverID::FromRandom()); + + uint8_t array1[] = {1, 2, 3, 4, 5, 6, 7, 8}; + uint8_t array2[] = {10, 11, 12, 13, 14, 15}; + + std::vector buffers; + buffers.emplace_back(array1, sizeof(array1)); + buffers.emplace_back(array2, sizeof(array2)); + + std::vector ids(buffers.size()); + for (size_t i = 0; i < ids.size(); i++) { + RAY_CHECK_OK(worker1.Objects().Put(buffers[i], &ids[i])); + } + + // Test Get() from remote node. + std::vector> results; + RAY_CHECK_OK(worker2.Objects().Get(ids, -1, &results)); + + ASSERT_EQ(results.size(), 2); + for (size_t i = 0; i < ids.size(); i++) { + ASSERT_EQ(results[i]->Size(), buffers[i].Size()); + ASSERT_EQ(memcmp(results[i]->Data(), buffers[i].Data(), buffers[i].Size()), 0); + } + + // Test Wait() from remote node. + ObjectID non_existent_id = ObjectID::FromRandom(); + std::vector all_ids(ids); + all_ids.push_back(non_existent_id); + + std::vector wait_results; + RAY_CHECK_OK(worker2.Objects().Wait(all_ids, 2, -1, &wait_results)); + ASSERT_EQ(wait_results.size(), 3); + ASSERT_EQ(wait_results, std::vector({true, true, false})); + + RAY_CHECK_OK(worker2.Objects().Wait(all_ids, 3, 100, &wait_results)); + ASSERT_EQ(wait_results.size(), 3); + ASSERT_EQ(wait_results, std::vector({true, true, false})); + + // Test Delete() from all machines. + // clear the reference held by PlasmaBuffer. 
+ results.clear(); + RAY_CHECK_OK(worker2.Objects().Delete(ids, false, false)); + + // Note that Delete() calls RayletClient::FreeObjects and would not + // wait for objects being deleted, so wait a while for plasma store + // to process the command. + usleep(1000 * 1000); + // Verify objects are deleted from both machines. + RAY_CHECK_OK(worker2.Objects().Get(ids, 0, &results)); + ASSERT_EQ(results.size(), 2); + ASSERT_TRUE(!results[0]); + ASSERT_TRUE(!results[1]); + + RAY_CHECK_OK(worker1.Objects().Get(ids, 0, &results)); + ASSERT_EQ(results.size(), 2); + ASSERT_TRUE(!results[0]); + ASSERT_TRUE(!results[1]); +} + +TEST_F(SingleNodeTest, TestNormalTaskLocal) { + std::unordered_map resources; + TestNormalTask(resources); +} + +TEST_F(TwoNodeTest, TestNormalTaskCrossNodes) { + std::unordered_map resources; + resources.emplace("resource1", 1); + TestNormalTask(resources); +} + +TEST_F(SingleNodeTest, TestActorTaskLocal) { + std::unordered_map resources; + TestActorTask(resources); +} + +TEST_F(TwoNodeTest, TestActorTaskCrossNodes) { + std::unordered_map resources; + resources.emplace("resource1", 1); + TestActorTask(resources); +} + +TEST_F(SingleNodeTest, TestCoreWorkerConstructorFailure) { + try { + CoreWorker core_worker(WorkerType::DRIVER, WorkerLanguage::PYTHON, "", + raylet_socket_names_[0], DriverID::FromRandom()); + } catch (const std::exception &e) { + std::cout << "Caught exception when constructing core worker: " << e.what(); + } +} + } // namespace ray int main(int argc, char **argv) { ::testing::InitGoogleTest(&argc, argv); + RAY_CHECK(argc == 4); ray::store_executable = std::string(argv[1]); ray::raylet_executable = std::string(argv[2]); + ray::mock_worker_executable = std::string(argv[3]); return RUN_ALL_TESTS(); } diff --git a/src/ray/core_worker/mock_worker.cc b/src/ray/core_worker/mock_worker.cc new file mode 100644 index 000000000000..a331a0b6ae12 --- /dev/null +++ b/src/ray/core_worker/mock_worker.cc @@ -0,0 +1,63 @@ +#include 
"ray/core_worker/context.h" +#include "ray/core_worker/core_worker.h" +#include "ray/core_worker/task_execution.h" + +namespace ray { + +/// A mock C++ worker used by core_worker_test.cc to verify the task submission/execution +/// interfaces in both single node and cross-nodes scenarios. As the raylet client can +/// only +/// be called by a real worker process, core_worker_test.cc has to use this program binary +/// to start the actual worker process, in the test, the task submission interfaces are +/// called +/// in core_worker_test, and task execution interfaces are called in this file, see that +/// test +/// for more details on how this class is used. +class MockWorker { + public: + MockWorker(const std::string &store_socket, const std::string &raylet_socket) + : worker_(WorkerType::WORKER, WorkerLanguage::PYTHON, store_socket, raylet_socket, + DriverID::FromRandom()) {} + + void Run() { + auto executor_func = [this](const RayFunction &ray_function, + const std::vector> &args, + const TaskID &task_id, int num_returns) { + // Note that this doesn't include dummy object id. + RAY_CHECK(num_returns >= 0); + + // Merge all the content from input args. + std::vector buffer; + for (const auto &arg : args) { + buffer.insert(buffer.end(), arg->Data(), arg->Data() + arg->Size()); + } + + LocalMemoryBuffer memory_buffer(buffer.data(), buffer.size()); + + // Write the merged content to each of return ids. + for (int i = 0; i < num_returns; i++) { + ObjectID id = ObjectID::ForTaskReturn(task_id, i + 1); + RAY_CHECK_OK(worker_.Objects().Put(memory_buffer, id)); + } + return Status::OK(); + }; + + // Start executing tasks. 
+ worker_.Execution().Run(executor_func); + } + + private: + CoreWorker worker_; +}; + +} // namespace ray + +int main(int argc, char **argv) { + RAY_CHECK(argc == 3); + auto store_socket = std::string(argv[1]); + auto raylet_socket = std::string(argv[2]); + + ray::MockWorker worker(store_socket, raylet_socket); + worker.Run(); + return 0; +} diff --git a/src/ray/core_worker/object_interface.cc b/src/ray/core_worker/object_interface.cc index 0b94c9d4a747..81777117cd14 100644 --- a/src/ray/core_worker/object_interface.cc +++ b/src/ray/core_worker/object_interface.cc @@ -1,128 +1,53 @@ -#include "object_interface.h" -#include "context.h" -#include "core_worker.h" +#include "ray/core_worker/object_interface.h" #include "ray/common/ray_config.h" +#include "ray/core_worker/context.h" +#include "ray/core_worker/core_worker.h" +#include "ray/core_worker/store_provider/plasma_store_provider.h" namespace ray { CoreWorkerObjectInterface::CoreWorkerObjectInterface(CoreWorker &core_worker) - : core_worker_(core_worker) {} + : core_worker_(core_worker) { + store_providers_.emplace( + static_cast(StoreProviderType::PLASMA), + std::unique_ptr(new CoreWorkerPlasmaStoreProvider( + core_worker_.store_client_, core_worker_.store_client_mutex_, + core_worker_.raylet_client_))); +} Status CoreWorkerObjectInterface::Put(const Buffer &buffer, ObjectID *object_id) { ObjectID put_id = ObjectID::ForPut(core_worker_.worker_context_.GetCurrentTaskID(), core_worker_.worker_context_.GetNextPutIndex()); *object_id = put_id; + return Put(buffer, put_id); +} - auto plasma_id = put_id.ToPlasmaId(); - std::shared_ptr data; - RAY_ARROW_RETURN_NOT_OK( - core_worker_.store_client_.Create(plasma_id, buffer.Size(), nullptr, 0, &data)); - memcpy(data->mutable_data(), buffer.Data(), buffer.Size()); - RAY_ARROW_RETURN_NOT_OK(core_worker_.store_client_.Seal(plasma_id)); - RAY_ARROW_RETURN_NOT_OK(core_worker_.store_client_.Release(plasma_id)); - return Status::OK(); +Status 
CoreWorkerObjectInterface::Put(const Buffer &buffer, const ObjectID &object_id) { + auto type = static_cast(StoreProviderType::PLASMA); + return store_providers_[type]->Put(buffer, object_id); } Status CoreWorkerObjectInterface::Get(const std::vector &ids, int64_t timeout_ms, std::vector> *results) { - (*results).resize(ids.size(), nullptr); - - bool was_blocked = false; - - std::unordered_map unready; - for (int i = 0; i < ids.size(); i++) { - unready.insert({ids[i], i}); - } - - int num_attempts = 0; - bool should_break = false; - int64_t remaining_timeout = timeout_ms; - // Repeat until we get all objects. - while (!unready.empty() && !should_break) { - std::vector unready_ids; - for (const auto &entry : unready) { - unready_ids.push_back(entry.first); - } - - // For the initial fetch, we only fetch the objects, do not reconstruct them. - bool fetch_only = num_attempts == 0; - if (!fetch_only) { - // If fetch_only is false, this worker will be blocked. - was_blocked = true; - } - - // TODO: can call `fetchOrReconstruct` in batches as an optimization. - RAY_CHECK_OK(core_worker_.raylet_client_->FetchOrReconstruct( - unready_ids, fetch_only, core_worker_.worker_context_.GetCurrentTaskID())); - - // Get the objects from the object store, and parse the result. 
- int64_t get_timeout; - if (remaining_timeout >= 0) { - get_timeout = - std::min(remaining_timeout, RayConfig::instance().get_timeout_milliseconds()); - remaining_timeout -= get_timeout; - should_break = remaining_timeout <= 0; - } else { - get_timeout = RayConfig::instance().get_timeout_milliseconds(); - } - - std::vector plasma_ids; - for (const auto &id : unready_ids) { - plasma_ids.push_back(id.ToPlasmaId()); - } - - std::vector object_buffers; - auto status = - core_worker_.store_client_.Get(plasma_ids, get_timeout, &object_buffers); - - for (int i = 0; i < object_buffers.size(); i++) { - if (object_buffers[i].data != nullptr) { - const auto &object_id = unready_ids[i]; - (*results)[unready[object_id]] = - std::make_shared(object_buffers[i].data); - unready.erase(object_id); - } - } - - num_attempts += 1; - // TODO: log a message if attempted too many times. - } - - if (was_blocked) { - RAY_CHECK_OK(core_worker_.raylet_client_->NotifyUnblocked( - core_worker_.worker_context_.GetCurrentTaskID())); - } - - return Status::OK(); + auto type = static_cast(StoreProviderType::PLASMA); + return store_providers_[type]->Get( + ids, timeout_ms, core_worker_.worker_context_.GetCurrentTaskID(), results); } Status CoreWorkerObjectInterface::Wait(const std::vector &object_ids, int num_objects, int64_t timeout_ms, std::vector *results) { - WaitResultPair result_pair; - auto status = core_worker_.raylet_client_->Wait( - object_ids, num_objects, timeout_ms, false, - core_worker_.worker_context_.GetCurrentTaskID(), &result_pair); - std::unordered_set ready_ids; - for (const auto &entry : result_pair.first) { - ready_ids.insert(entry); - } - - // TODO: change RayletClient::Wait() to return a bit set, so that we don't need - // to do this translation. 
- (*results).resize(object_ids.size()); - for (int i = 0; i < object_ids.size(); i++) { - (*results)[i] = ready_ids.count(object_ids[i]) > 0; - } - - return status; + auto type = static_cast(StoreProviderType::PLASMA); + return store_providers_[type]->Wait(object_ids, num_objects, timeout_ms, + core_worker_.worker_context_.GetCurrentTaskID(), + results); } Status CoreWorkerObjectInterface::Delete(const std::vector &object_ids, bool local_only, bool delete_creating_tasks) { - return core_worker_.raylet_client_->FreeObjects(object_ids, local_only, - delete_creating_tasks); + auto type = static_cast(StoreProviderType::PLASMA); + return store_providers_[type]->Delete(object_ids, local_only, delete_creating_tasks); } } // namespace ray diff --git a/src/ray/core_worker/object_interface.h b/src/ray/core_worker/object_interface.h index 8a9e20c48c6e..35403675f164 100644 --- a/src/ray/core_worker/object_interface.h +++ b/src/ray/core_worker/object_interface.h @@ -1,15 +1,17 @@ #ifndef RAY_CORE_WORKER_OBJECT_INTERFACE_H #define RAY_CORE_WORKER_OBJECT_INTERFACE_H -#include "common.h" #include "plasma/client.h" #include "ray/common/buffer.h" #include "ray/common/id.h" #include "ray/common/status.h" +#include "ray/core_worker/common.h" +#include "ray/core_worker/store_provider/store_provider.h" namespace ray { class CoreWorker; +class CoreWorkerStoreProvider; /// The interface that contains all `CoreWorker` methods that are related to object store. class CoreWorkerObjectInterface { @@ -23,6 +25,13 @@ class CoreWorkerObjectInterface { /// \return Status. Status Put(const Buffer &buffer, ObjectID *object_id); + /// Put an object with specified ID into object store. + /// + /// \param[in] buffer Data buffer of the object. + /// \param[in] object_id Object ID specified by user. + /// \return Status. + Status Put(const Buffer &buffer, const ObjectID &object_id); + /// Get a list of objects from the object store. /// /// \param[in] ids IDs of the objects to get. 
@@ -55,6 +64,9 @@ class CoreWorkerObjectInterface { private: /// Reference to the parent CoreWorker instance. CoreWorker &core_worker_; + + /// All the store providers supported. + std::unordered_map> store_providers_; }; } // namespace ray diff --git a/src/ray/core_worker/store_provider/plasma_store_provider.cc b/src/ray/core_worker/store_provider/plasma_store_provider.cc new file mode 100644 index 000000000000..b5dd91d82881 --- /dev/null +++ b/src/ray/core_worker/store_provider/plasma_store_provider.cc @@ -0,0 +1,139 @@ +#include "ray/core_worker/store_provider/plasma_store_provider.h" +#include "ray/common/ray_config.h" +#include "ray/core_worker/context.h" +#include "ray/core_worker/core_worker.h" +#include "ray/core_worker/object_interface.h" + +namespace ray { + +CoreWorkerPlasmaStoreProvider::CoreWorkerPlasmaStoreProvider( + plasma::PlasmaClient &store_client, std::mutex &store_client_mutex, + RayletClient &raylet_client) + : store_client_(store_client), + store_client_mutex_(store_client_mutex), + raylet_client_(raylet_client) {} + +Status CoreWorkerPlasmaStoreProvider::Put(const Buffer &buffer, + const ObjectID &object_id) { + auto plasma_id = object_id.ToPlasmaId(); + std::shared_ptr data; + { + std::unique_lock guard(store_client_mutex_); + RAY_ARROW_RETURN_NOT_OK( + store_client_.Create(plasma_id, buffer.Size(), nullptr, 0, &data)); + } + + memcpy(data->mutable_data(), buffer.Data(), buffer.Size()); + + { + std::unique_lock guard(store_client_mutex_); + RAY_ARROW_RETURN_NOT_OK(store_client_.Seal(plasma_id)); + RAY_ARROW_RETURN_NOT_OK(store_client_.Release(plasma_id)); + } + return Status::OK(); +} + +Status CoreWorkerPlasmaStoreProvider::Get(const std::vector &ids, + int64_t timeout_ms, const TaskID &task_id, + std::vector> *results) { + (*results).resize(ids.size(), nullptr); + + bool was_blocked = false; + + std::unordered_map unready; + for (size_t i = 0; i < ids.size(); i++) { + unready.insert({ids[i], i}); + } + + int num_attempts = 0; + bool 
should_break = false; + int64_t remaining_timeout = timeout_ms; + // Repeat until we get all objects. + while (!unready.empty() && !should_break) { + std::vector unready_ids; + for (const auto &entry : unready) { + unready_ids.push_back(entry.first); + } + + // For the initial fetch, we only fetch the objects, do not reconstruct them. + bool fetch_only = num_attempts == 0; + if (!fetch_only) { + // If fetch_only is false, this worker will be blocked. + was_blocked = true; + } + + // TODO(zhijunfu): can call `fetchOrReconstruct` in batches as an optimization. + RAY_CHECK_OK(raylet_client_.FetchOrReconstruct(unready_ids, fetch_only, task_id)); + + // Get the objects from the object store, and parse the result. + int64_t get_timeout; + if (remaining_timeout >= 0) { + get_timeout = + std::min(remaining_timeout, RayConfig::instance().get_timeout_milliseconds()); + remaining_timeout -= get_timeout; + should_break = remaining_timeout <= 0; + } else { + get_timeout = RayConfig::instance().get_timeout_milliseconds(); + } + + std::vector plasma_ids; + for (const auto &id : unready_ids) { + plasma_ids.push_back(id.ToPlasmaId()); + } + + std::vector object_buffers; + { + std::unique_lock guard(store_client_mutex_); + auto status = store_client_.Get(plasma_ids, get_timeout, &object_buffers); + } + + for (size_t i = 0; i < object_buffers.size(); i++) { + if (object_buffers[i].data != nullptr) { + const auto &object_id = unready_ids[i]; + (*results)[unready[object_id]] = + std::make_shared(object_buffers[i].data); + unready.erase(object_id); + } + } + + num_attempts += 1; + // TODO(zhijunfu): log a message if attempted too many times. 
+ } + + if (was_blocked) { + RAY_CHECK_OK(raylet_client_.NotifyUnblocked(task_id)); + } + + return Status::OK(); +} + +Status CoreWorkerPlasmaStoreProvider::Wait(const std::vector &object_ids, + int num_objects, int64_t timeout_ms, + const TaskID &task_id, + std::vector *results) { + WaitResultPair result_pair; + auto status = raylet_client_.Wait(object_ids, num_objects, timeout_ms, false, task_id, + &result_pair); + std::unordered_set ready_ids; + for (const auto &entry : result_pair.first) { + ready_ids.insert(entry); + } + + // TODO(zhijunfu): change RayletClient::Wait() to return a bit set, so that we don't + // need + // to do this translation. + (*results).resize(object_ids.size()); + for (size_t i = 0; i < object_ids.size(); i++) { + (*results)[i] = ready_ids.count(object_ids[i]) > 0; + } + + return status; +} + +Status CoreWorkerPlasmaStoreProvider::Delete(const std::vector &object_ids, + bool local_only, + bool delete_creating_tasks) { + return raylet_client_.FreeObjects(object_ids, local_only, delete_creating_tasks); +} + +} // namespace ray diff --git a/src/ray/core_worker/store_provider/plasma_store_provider.h b/src/ray/core_worker/store_provider/plasma_store_provider.h new file mode 100644 index 000000000000..0dfce1eb1e45 --- /dev/null +++ b/src/ray/core_worker/store_provider/plasma_store_provider.h @@ -0,0 +1,76 @@ +#ifndef RAY_CORE_WORKER_PLASMA_STORE_PROVIDER_H +#define RAY_CORE_WORKER_PLASMA_STORE_PROVIDER_H + +#include "plasma/client.h" +#include "ray/common/buffer.h" +#include "ray/common/id.h" +#include "ray/common/status.h" +#include "ray/core_worker/common.h" +#include "ray/core_worker/store_provider/store_provider.h" +#include "ray/raylet/raylet_client.h" + +namespace ray { + +class CoreWorker; + +/// The class provides implementations for accessing plasma store, which includes both +/// local and remote store, remote access is done via raylet. 
+class CoreWorkerPlasmaStoreProvider : public CoreWorkerStoreProvider { + public: + CoreWorkerPlasmaStoreProvider(plasma::PlasmaClient &store_client, + std::mutex &store_client_mutex, + RayletClient &raylet_client); + + /// Put an object with specified ID into object store. + /// + /// \param[in] buffer Data buffer of the object. + /// \param[in] object_id Object ID specified by user. + /// \return Status. + Status Put(const Buffer &buffer, const ObjectID &object_id) override; + + /// Get a list of objects from the object store. + /// + /// \param[in] ids IDs of the objects to get. + /// \param[in] timeout_ms Timeout in milliseconds, wait infinitely if it's negative. + /// \param[in] task_id ID for the current task. + /// \param[out] results Result list of objects data. + /// \return Status. + Status Get(const std::vector &ids, int64_t timeout_ms, const TaskID &task_id, + std::vector> *results) override; + + /// Wait for a list of objects to appear in the object store. + /// + /// \param[in] IDs of the objects to wait for. + /// \param[in] num_returns Number of objects that should appear. + /// \param[in] timeout_ms Timeout in milliseconds, wait infinitely if it's negative. + /// \param[in] task_id ID for the current task. + /// \param[out] results A bitset that indicates each object has appeared or not. + /// \return Status. + Status Wait(const std::vector &object_ids, int num_objects, + int64_t timeout_ms, const TaskID &task_id, + std::vector *results) override; + + /// Delete a list of objects from the object store. + /// + /// \param[in] object_ids IDs of the objects to delete. + /// \param[in] local_only Whether only delete the objects in local node, or all nodes in + /// the cluster. + /// \param[in] delete_creating_tasks Whether also delete the tasks that + /// created these objects. \return Status. + Status Delete(const std::vector &object_ids, bool local_only, + bool delete_creating_tasks) override; + + private: + /// Plasma store client. 
+ plasma::PlasmaClient &store_client_; + + /// Mutex to protect store_client_. + std::mutex &store_client_mutex_; + + /// Raylet client. + RayletClient &raylet_client_; +}; + +} // namespace ray + +#endif // RAY_CORE_WORKER_PLASMA_STORE_PROVIDER_H diff --git a/src/ray/core_worker/store_provider/store_provider.h b/src/ray/core_worker/store_provider/store_provider.h new file mode 100644 index 000000000000..f1521edf1626 --- /dev/null +++ b/src/ray/core_worker/store_provider/store_provider.h @@ -0,0 +1,64 @@ +#ifndef RAY_CORE_WORKER_STORE_PROVIDER_H +#define RAY_CORE_WORKER_STORE_PROVIDER_H + +#include "ray/common/buffer.h" +#include "ray/common/id.h" +#include "ray/common/status.h" +#include "ray/core_worker/common.h" + +namespace ray { + +/// Provider interface for store access. Store provider should inherit from this class and +/// provide implementions for the methods. The actual store provider may use a plasma +/// store or local memory store in worker process, or possibly other types of storage. + +class CoreWorkerStoreProvider { + public: + CoreWorkerStoreProvider() {} + + virtual ~CoreWorkerStoreProvider() {} + + /// Put an object with specified ID into object store. + /// + /// \param[in] buffer Data buffer of the object. + /// \param[in] object_id Object ID specified by user. + /// \return Status. + virtual Status Put(const Buffer &buffer, const ObjectID &object_id) = 0; + + /// Get a list of objects from the object store. + /// + /// \param[in] ids IDs of the objects to get. + /// \param[in] timeout_ms Timeout in milliseconds, wait infinitely if it's negative. + /// \param[in] task_id ID for the current task. + /// \param[out] results Result list of objects data. + /// \return Status. + virtual Status Get(const std::vector &ids, int64_t timeout_ms, + const TaskID &task_id, + std::vector> *results) = 0; + + /// Wait for a list of objects to appear in the object store. + /// + /// \param[in] IDs of the objects to wait for. 
+ /// \param[in] num_returns Number of objects that should appear. + /// \param[in] timeout_ms Timeout in milliseconds, wait infinitely if it's negative. + /// \param[in] task_id ID for the current task. + /// \param[out] results A bitset that indicates each object has appeared or not. + /// \return Status. + virtual Status Wait(const std::vector &object_ids, int num_objects, + int64_t timeout_ms, const TaskID &task_id, + std::vector *results) = 0; + + /// Delete a list of objects from the object store. + /// + /// \param[in] object_ids IDs of the objects to delete. + /// \param[in] local_only Whether only delete the objects in local node, or all nodes in + /// the cluster. + /// \param[in] delete_creating_tasks Whether also delete the tasks that + /// created these objects. \return Status. + virtual Status Delete(const std::vector &object_ids, bool local_only, + bool delete_creating_tasks) = 0; +}; + +} // namespace ray + +#endif // RAY_CORE_WORKER_STORE_PROVIDER_H diff --git a/src/ray/core_worker/task_execution.cc b/src/ray/core_worker/task_execution.cc index aea48b4de34a..701ae3124c97 100644 --- a/src/ray/core_worker/task_execution.cc +++ b/src/ray/core_worker/task_execution.cc @@ -1,7 +1,91 @@ -#include "task_execution.h" +#include "ray/core_worker/task_execution.h" +#include "ray/core_worker/context.h" +#include "ray/core_worker/core_worker.h" +#include "ray/core_worker/transport/raylet_transport.h" namespace ray { -void CoreWorkerTaskExecutionInterface::Start(const TaskExecutor &executor) {} +CoreWorkerTaskExecutionInterface::CoreWorkerTaskExecutionInterface( + CoreWorker &core_worker) + : core_worker_(core_worker) { + task_receivers.emplace( + static_cast(TaskTransportType::RAYLET), + std::unique_ptr( + new CoreWorkerRayletTaskReceiver(core_worker_.raylet_client_))); +} + +Status CoreWorkerTaskExecutionInterface::Run(const TaskExecutor &executor) { + while (true) { + std::vector tasks; + auto status = + 
task_receivers[static_cast(TaskTransportType::RAYLET)]->GetTasks(&tasks); + if (!status.ok()) { + RAY_LOG(ERROR) << "Getting task failed with error: " + << ray::Status::IOError(status.message()); + return status; + } + + for (const auto &task : tasks) { + const auto &spec = task.GetTaskSpecification(); + core_worker_.worker_context_.SetCurrentTask(spec); + + WorkerLanguage language = (spec.GetLanguage() == ::Language::JAVA) + ? WorkerLanguage::JAVA + : WorkerLanguage::PYTHON; + RayFunction func{language, spec.FunctionDescriptor()}; + + std::vector> args; + RAY_CHECK_OK(BuildArgsForExecutor(spec, &args)); + + auto num_returns = spec.NumReturns(); + if (spec.IsActorCreationTask() || spec.IsActorTask()) { + RAY_CHECK(num_returns > 0); + // Decrease to account for the dummy object id. + num_returns--; + } + + status = executor(func, args, spec.TaskId(), num_returns); + // TODO(zhijunfu): + // 1. Check and handle failure. + // 2. Save or load checkpoint. + } + } + + // should never reach here. + return Status::OK(); +} + +Status CoreWorkerTaskExecutionInterface::BuildArgsForExecutor( + const raylet::TaskSpecification &spec, std::vector> *args) { + auto num_args = spec.NumArgs(); + (*args).resize(num_args); + + std::vector object_ids_to_fetch; + std::vector indices; + + for (int i = 0; i < spec.NumArgs(); ++i) { + int count = spec.ArgIdCount(i); + if (count > 0) { + // pass by reference. + RAY_CHECK(count == 1); + object_ids_to_fetch.push_back(spec.ArgId(i, 0)); + indices.push_back(i); + } else { + // pass by value. 
+ (*args)[i] = std::make_shared( + const_cast(spec.ArgVal(i)), spec.ArgValLength(i)); + } + } + + std::vector> results; + auto status = core_worker_.object_interface_.Get(object_ids_to_fetch, -1, &results); + if (status.ok()) { + for (size_t i = 0; i < results.size(); i++) { + (*args)[indices[i]] = results[i]; + } + } + + return status; +} } // namespace ray diff --git a/src/ray/core_worker/task_execution.h b/src/ray/core_worker/task_execution.h index c4de937ee439..f4b44b9e131d 100644 --- a/src/ray/core_worker/task_execution.h +++ b/src/ray/core_worker/task_execution.h @@ -1,34 +1,54 @@ #ifndef RAY_CORE_WORKER_TASK_EXECUTION_H #define RAY_CORE_WORKER_TASK_EXECUTION_H -#include "common.h" #include "ray/common/buffer.h" #include "ray/common/status.h" +#include "ray/core_worker/common.h" +#include "ray/core_worker/transport/transport.h" namespace ray { class CoreWorker; +namespace raylet { +class TaskSpecification; +} + /// The interface that contains all `CoreWorker` methods that are related to task /// execution. class CoreWorkerTaskExecutionInterface { public: - CoreWorkerTaskExecutionInterface(CoreWorker &core_worker) : core_worker_(core_worker) {} - + CoreWorkerTaskExecutionInterface(CoreWorker &core_worker); /// The callback provided app-language workers that executes tasks. /// /// \param ray_function[in] Information about the function to execute. /// \param args[in] Arguments of the task. /// \return Status. - using TaskExecutor = std::function &args)>; + using TaskExecutor = std::function> &args, + const TaskID &task_id, int num_returns)>; /// Start receving and executes tasks in a infinite loop. - void Start(const TaskExecutor &executor); + /// \return Status. + Status Run(const TaskExecutor &executor); private: + /// Build arguments for task executor. 
This would loop through all the arguments + /// in task spec, and for each of them that's passed by reference (ObjectID), + /// fetch its content from store and; for arguments that are passed by value, + /// just copy their content. + /// + /// \param spec[in] Task specification. + /// \param args[out] The arguments for passing to task executor. + /// + Status BuildArgsForExecutor(const raylet::TaskSpecification &spec, + std::vector> *args); + /// Reference to the parent CoreWorker instance. CoreWorker &core_worker_; + + /// All the task task receivers supported. + std::unordered_map> task_receivers; }; } // namespace ray diff --git a/src/ray/core_worker/task_interface.cc b/src/ray/core_worker/task_interface.cc index ab8b8950c298..6a91bd6b2101 100644 --- a/src/ray/core_worker/task_interface.cc +++ b/src/ray/core_worker/task_interface.cc @@ -1,18 +1,80 @@ -#include "task_interface.h" +#include "ray/core_worker/task_interface.h" +#include "ray/core_worker/context.h" +#include "ray/core_worker/core_worker.h" +#include "ray/core_worker/task_interface.h" +#include "ray/core_worker/transport/raylet_transport.h" namespace ray { +CoreWorkerTaskInterface::CoreWorkerTaskInterface(CoreWorker &core_worker) + : core_worker_(core_worker) { + task_submitters_.emplace( + static_cast(TaskTransportType::RAYLET), + std::unique_ptr( + new CoreWorkerRayletTaskSubmitter(core_worker_.raylet_client_))); +} + Status CoreWorkerTaskInterface::SubmitTask(const RayFunction &function, const std::vector &args, const TaskOptions &task_options, std::vector *return_ids) { - return Status::OK(); + auto &context = core_worker_.worker_context_; + auto next_task_index = context.GetNextTaskIndex(); + const auto task_id = GenerateTaskId(context.GetCurrentDriverID(), + context.GetCurrentTaskID(), next_task_index); + + auto num_returns = task_options.num_returns; + (*return_ids).resize(num_returns); + for (int i = 0; i < num_returns; i++) { + (*return_ids)[i] = ObjectID::ForTaskReturn(task_id, i + 1); + } 
+ + auto task_arguments = BuildTaskArguments(args); + auto language = core_worker_.ToTaskLanguage(function.language); + + ray::raylet::TaskSpecification spec(context.GetCurrentDriverID(), + context.GetCurrentTaskID(), next_task_index, + task_arguments, num_returns, task_options.resources, + language, function.function_descriptor); + + std::vector execution_dependencies; + TaskSpec task(std::move(spec), execution_dependencies); + return task_submitters_[static_cast(TaskTransportType::RAYLET)]->SubmitTask(task); } Status CoreWorkerTaskInterface::CreateActor( const RayFunction &function, const std::vector &args, - const ActorCreationOptions &actor_creation_options, ActorHandle *actor_handle) { - return Status::OK(); + const ActorCreationOptions &actor_creation_options, + std::unique_ptr *actor_handle) { + auto &context = core_worker_.worker_context_; + auto next_task_index = context.GetNextTaskIndex(); + const auto task_id = GenerateTaskId(context.GetCurrentDriverID(), + context.GetCurrentTaskID(), next_task_index); + + std::vector return_ids; + return_ids.push_back(ObjectID::ForTaskReturn(task_id, 1)); + ActorID actor_creation_id = ActorID::FromBinary(return_ids[0].Binary()); + + *actor_handle = std::unique_ptr( + new ActorHandle(actor_creation_id, ActorHandleID::Nil())); + (*actor_handle)->IncreaseTaskCounter(); + (*actor_handle)->SetActorCursor(return_ids[0]); + + auto task_arguments = BuildTaskArguments(args); + auto language = core_worker_.ToTaskLanguage(function.language); + + // Note that the caller is supposed to specify required placement resources + // correctly via actor_creation_options.resources. 
+ ray::raylet::TaskSpecification spec( + context.GetCurrentDriverID(), context.GetCurrentTaskID(), next_task_index, + actor_creation_id, ObjectID::Nil(), actor_creation_options.max_reconstructions, + ActorID::Nil(), ActorHandleID::Nil(), 0, {}, task_arguments, 1, + actor_creation_options.resources, actor_creation_options.resources, language, + function.function_descriptor); + + std::vector execution_dependencies; + TaskSpec task(std::move(spec), execution_dependencies); + return task_submitters_[static_cast(TaskTransportType::RAYLET)]->SubmitTask(task); } Status CoreWorkerTaskInterface::SubmitActorTask(ActorHandle &actor_handle, @@ -20,7 +82,63 @@ Status CoreWorkerTaskInterface::SubmitActorTask(ActorHandle &actor_handle, const std::vector &args, const TaskOptions &task_options, std::vector *return_ids) { - return Status::OK(); + auto &context = core_worker_.worker_context_; + auto next_task_index = context.GetNextTaskIndex(); + const auto task_id = GenerateTaskId(context.GetCurrentDriverID(), + context.GetCurrentTaskID(), next_task_index); + + // add one for actor cursor object id. 
+ auto num_returns = task_options.num_returns + 1; + (*return_ids).resize(num_returns); + for (int i = 0; i < num_returns; i++) { + (*return_ids)[i] = ObjectID::ForTaskReturn(task_id, i + 1); + } + + auto actor_creation_dummy_object_id = + ObjectID::FromBinary(actor_handle.ActorID().Binary()); + + auto task_arguments = BuildTaskArguments(args); + auto language = core_worker_.ToTaskLanguage(function.language); + + std::vector new_actor_handles; + ray::raylet::TaskSpecification spec( + context.GetCurrentDriverID(), context.GetCurrentTaskID(), next_task_index, + ActorID::Nil(), actor_creation_dummy_object_id, 0, actor_handle.ActorID(), + actor_handle.ActorHandleID(), actor_handle.IncreaseTaskCounter(), new_actor_handles, + task_arguments, num_returns, task_options.resources, task_options.resources, + language, function.function_descriptor); + + std::vector execution_dependencies; + execution_dependencies.push_back(actor_handle.ActorCursor()); + + auto actor_cursor = (*return_ids).back(); + actor_handle.SetActorCursor(actor_cursor); + actor_handle.ClearNewActorHandles(); + + TaskSpec task(std::move(spec), execution_dependencies); + auto status = + task_submitters_[static_cast(TaskTransportType::RAYLET)]->SubmitTask(task); + + // remove cursor from return ids. 
+ (*return_ids).pop_back(); + return status; +} + +std::vector> +CoreWorkerTaskInterface::BuildTaskArguments(const std::vector &args) { + std::vector> task_arguments; + for (const auto &arg : args) { + if (arg.IsPassedByReference()) { + std::vector references{arg.GetReference()}; + task_arguments.push_back( + std::make_shared(references)); + } else { + auto data = arg.GetValue(); + task_arguments.push_back( + std::make_shared(data->Data(), data->Size())); + } + } + return task_arguments; } } // namespace ray diff --git a/src/ray/core_worker/task_interface.h b/src/ray/core_worker/task_interface.h index e23f049d341d..e59934f9b51d 100644 --- a/src/ray/core_worker/task_interface.h +++ b/src/ray/core_worker/task_interface.h @@ -1,10 +1,14 @@ #ifndef RAY_CORE_WORKER_TASK_INTERFACE_H #define RAY_CORE_WORKER_TASK_INTERFACE_H -#include "common.h" +#include + #include "ray/common/buffer.h" #include "ray/common/id.h" #include "ray/common/status.h" +#include "ray/core_worker/common.h" +#include "ray/core_worker/transport/transport.h" +#include "ray/raylet/task.h" namespace ray { @@ -12,6 +16,10 @@ class CoreWorker; /// Options of a non-actor-creation task. struct TaskOptions { + TaskOptions() {} + TaskOptions(int num_returns, const std::unordered_map &resources) + : num_returns(num_returns), resources(resources) {} + /// Number of returns of this task. const int num_returns = 1; /// Resources required by this task. @@ -20,6 +28,11 @@ struct TaskOptions { /// Options of an actor creation task. struct ActorCreationOptions { + ActorCreationOptions() {} + ActorCreationOptions(uint64_t max_reconstructions, + const std::unordered_map &resources) + : max_reconstructions(max_reconstructions), resources(resources) {} + /// Maximum number of times that the actor should be reconstructed when it dies /// unexpectedly. It must be non-negative. If it's 0, the actor won't be reconstructed. 
const uint64_t max_reconstructions = 0; @@ -31,26 +44,53 @@ struct ActorCreationOptions { class ActorHandle { public: ActorHandle(const ActorID &actor_id, const ActorHandleID &actor_handle_id) - : actor_id_(actor_id), actor_handle_id_(actor_handle_id) {} + : actor_id_(actor_id), + actor_handle_id_(actor_handle_id), + actor_cursor_(ObjectID::FromBinary(actor_id.Binary())), + task_counter_(0) {} /// ID of the actor. - const class ActorID &ActorID() const { return actor_id_; } + const ray::ActorID &ActorID() const { return actor_id_; }; /// ID of this actor handle. - const class ActorHandleID &ActorHandleID() const { return actor_handle_id_; } + const ray::ActorHandleID &ActorHandleID() const { return actor_handle_id_; }; + + private: + /// Cursor of this actor. + const ObjectID &ActorCursor() const { return actor_cursor_; }; + + /// Set actor cursor. + void SetActorCursor(const ObjectID &actor_cursor) { actor_cursor_ = actor_cursor; }; + + /// Increase task counter. + int IncreaseTaskCounter() { return task_counter_++; } + + std::list GetNewActorHandle() { + // TODO(zhijunfu): implement this. + return std::list(); + } + + void ClearNewActorHandles() { /* TODO(zhijunfu): implement this. */ + } private: /// ID of the actor. - const class ActorID actor_id_; + const ray::ActorID actor_id_; /// ID of this actor handle. - const class ActorHandleID actor_handle_id_; + const ray::ActorHandleID actor_handle_id_; + /// ID of this actor cursor. + ObjectID actor_cursor_; + /// Counter for tasks from this handle. + int task_counter_; + + friend class CoreWorkerTaskInterface; }; /// The interface that contains all `CoreWorker` methods that are related to task /// submission. class CoreWorkerTaskInterface { public: - CoreWorkerTaskInterface(CoreWorker &core_worker) : core_worker_(core_worker) {} + CoreWorkerTaskInterface(CoreWorker &core_worker); /// Submit a normal task. /// @@ -71,7 +111,7 @@ class CoreWorkerTaskInterface { /// \return Status. 
Status CreateActor(const RayFunction &function, const std::vector &args, const ActorCreationOptions &actor_creation_options, - ActorHandle *actor_handle); + std::unique_ptr *actor_handle); /// Submit an actor task. /// @@ -89,6 +129,17 @@ class CoreWorkerTaskInterface { private: /// Reference to the parent CoreWorker instance. CoreWorker &core_worker_; + + private: + /// Build the arguments for a task spec. + /// + /// \param[in] args Arguments of a task. + /// \return Arguments as required by task spec. + std::vector> BuildTaskArguments( + const std::vector &args); + + /// All the task submitters supported. + std::unordered_map> task_submitters_; }; } // namespace ray diff --git a/src/ray/core_worker/transport/raylet_transport.cc b/src/ray/core_worker/transport/raylet_transport.cc new file mode 100644 index 000000000000..14906acfe0bf --- /dev/null +++ b/src/ray/core_worker/transport/raylet_transport.cc @@ -0,0 +1,32 @@ + +#include "ray/core_worker/transport/raylet_transport.h" + +namespace ray { + +CoreWorkerRayletTaskSubmitter::CoreWorkerRayletTaskSubmitter(RayletClient &raylet_client) + : raylet_client_(raylet_client) {} + +Status CoreWorkerRayletTaskSubmitter::SubmitTask(const TaskSpec &task) { + return raylet_client_.SubmitTask(task.GetDependencies(), task.GetTaskSpecification()); +} + +CoreWorkerRayletTaskReceiver::CoreWorkerRayletTaskReceiver(RayletClient &raylet_client) + : raylet_client_(raylet_client) {} + +Status CoreWorkerRayletTaskReceiver::GetTasks(std::vector *tasks) { + std::unique_ptr task_spec; + auto status = raylet_client_.GetTask(&task_spec); + if (!status.ok()) { + RAY_LOG(ERROR) << "Get task from raylet failed with error: " + << ray::Status::IOError(status.message()); + return status; + } + + std::vector dependencies; + RAY_CHECK((*tasks).empty()); + (*tasks).emplace_back(*task_spec, dependencies); + + return Status::OK(); +} + +} // namespace ray diff --git a/src/ray/core_worker/transport/raylet_transport.h 
b/src/ray/core_worker/transport/raylet_transport.h new file mode 100644 index 000000000000..03bf82f29886 --- /dev/null +++ b/src/ray/core_worker/transport/raylet_transport.h @@ -0,0 +1,44 @@ +#ifndef RAY_CORE_WORKER_RAYLET_TRANSPORT_H +#define RAY_CORE_WORKER_RAYLET_TRANSPORT_H + +#include + +#include "ray/core_worker/transport/transport.h" +#include "ray/raylet/raylet_client.h" + +namespace ray { + +/// In raylet task submitter and receiver, a task is submitted to raylet, and possibly +/// gets forwarded to another raylet on which node the task should be executed, and +/// then a worker on that node gets this task and starts executing it. + +class CoreWorkerRayletTaskSubmitter : public CoreWorkerTaskSubmitter { + public: + CoreWorkerRayletTaskSubmitter(RayletClient &raylet_client); + + /// Submit a task for execution to raylet. + /// + /// \param[in] task The task spec to submit. + /// \return Status. + virtual Status SubmitTask(const TaskSpec &task) override; + + private: + /// Raylet client. + RayletClient &raylet_client_; +}; + +class CoreWorkerRayletTaskReceiver : public CoreWorkerTaskReceiver { + public: + CoreWorkerRayletTaskReceiver(RayletClient &raylet_client); + + // Get tasks for execution from raylet. + virtual Status GetTasks(std::vector *tasks) override; + + private: + /// Raylet client. + RayletClient &raylet_client_; +}; + +} // namespace ray + +#endif // RAY_CORE_WORKER_RAYLET_TRANSPORT_H diff --git a/src/ray/core_worker/transport/transport.h b/src/ray/core_worker/transport/transport.h new file mode 100644 index 000000000000..44be74b989c7 --- /dev/null +++ b/src/ray/core_worker/transport/transport.h @@ -0,0 +1,41 @@ +#ifndef RAY_CORE_WORKER_TRANSPORT_H +#define RAY_CORE_WORKER_TRANSPORT_H + +#include + +#include "ray/common/buffer.h" +#include "ray/common/id.h" +#include "ray/common/status.h" +#include "ray/core_worker/common.h" +#include "ray/raylet/task_spec.h" + +namespace ray { + +/// Interfaces for task submitter and receiver. 
They are separate classes but should be +/// used in pairs - one type of task submitter should be used together with task +/// with the same type, so these classes are put together in this same file. +/// +/// Task submitter/receiver should inherit from these classes and provide implementions +/// for the methods. The actual task submitter/receiver can submit/get tasks via raylet, +/// or directly to/from another worker. + +/// This class is responsible to submit tasks. +class CoreWorkerTaskSubmitter { + public: + /// Submit a task for execution. + /// + /// \param[in] task The task spec to submit. + /// \return Status. + virtual Status SubmitTask(const TaskSpec &task) = 0; +}; + +/// This class receives tasks for execution. +class CoreWorkerTaskReceiver { + public: + // Get tasks for execution. + virtual Status GetTasks(std::vector *tasks) = 0; +}; + +} // namespace ray + +#endif // RAY_CORE_WORKER_TRANSPORT_H diff --git a/src/ray/gcs/client.cc b/src/ray/gcs/client.cc index d9b5087c4719..c9b1e138575d 100644 --- a/src/ray/gcs/client.cc +++ b/src/ray/gcs/client.cc @@ -23,8 +23,8 @@ static void GetRedisShards(redisContext *context, std::vector &addr } RAY_CHECK(num_attempts < RayConfig::instance().redis_db_connect_retries()) << "No entry found for NumRedisShards"; - RAY_CHECK(reply->type == REDIS_REPLY_STRING) << "Expected string, found Redis type " - << reply->type << " for NumRedisShards"; + RAY_CHECK(reply->type == REDIS_REPLY_STRING) + << "Expected string, found Redis type " << reply->type << " for NumRedisShards"; int num_redis_shards = atoi(reply->str); RAY_CHECK(num_redis_shards >= 1) << "Expected at least one Redis shard, " << "found " << num_redis_shards; @@ -120,6 +120,7 @@ AsyncGcsClient::AsyncGcsClient(const std::string &address, int port, profile_table_.reset(new ProfileTable(shard_contexts_, this)); actor_checkpoint_table_.reset(new ActorCheckpointTable(shard_contexts_, this)); actor_checkpoint_id_table_.reset(new ActorCheckpointIdTable(shard_contexts_, 
this)); + resource_table_.reset(new DynamicResourceTable({primary_context_}, this)); command_type_ = command_type; // TODO(swang): Call the client table's Connect() method here. To do this, @@ -229,6 +230,8 @@ ActorCheckpointIdTable &AsyncGcsClient::actor_checkpoint_id_table() { return *actor_checkpoint_id_table_; } +DynamicResourceTable &AsyncGcsClient::resource_table() { return *resource_table_; } + } // namespace gcs } // namespace ray diff --git a/src/ray/gcs/client.h b/src/ray/gcs/client.h index d47d9a6e8b24..c9f5b4bca624 100644 --- a/src/ray/gcs/client.h +++ b/src/ray/gcs/client.h @@ -62,6 +62,7 @@ class RAY_EXPORT AsyncGcsClient { ProfileTable &profile_table(); ActorCheckpointTable &actor_checkpoint_table(); ActorCheckpointIdTable &actor_checkpoint_id_table(); + DynamicResourceTable &resource_table(); // We also need something to export generic code to run on workers from the // driver (to set the PYTHONPATH) @@ -94,6 +95,7 @@ class RAY_EXPORT AsyncGcsClient { std::unique_ptr client_table_; std::unique_ptr actor_checkpoint_table_; std::unique_ptr actor_checkpoint_id_table_; + std::unique_ptr resource_table_; // The following contexts write to the data shard std::vector> shard_contexts_; std::vector> shard_asio_async_clients_; diff --git a/src/ray/gcs/client_test.cc b/src/ray/gcs/client_test.cc index 1b43bcc23c08..c7dc02e50651 100644 --- a/src/ray/gcs/client_test.cc +++ b/src/ray/gcs/client_test.cc @@ -150,8 +150,8 @@ void TestLogLookup(const DriverID &driver_id, // Check that lookup returns the added object entries. 
auto lookup_callback = [task_id, node_manager_ids]( - gcs::AsyncGcsClient *client, const TaskID &id, - const std::vector &data) { + gcs::AsyncGcsClient *client, const TaskID &id, + const std::vector &data) { ASSERT_EQ(id, task_id); for (const auto &entry : data) { ASSERT_EQ(entry.node_manager_id, node_manager_ids[test->NumCallbacks()]); @@ -241,8 +241,8 @@ void TestLogAppendAt(const DriverID &driver_id, /*done callback=*/nullptr, failure_callback, /*log_length=*/1)); auto lookup_callback = [node_manager_ids]( - gcs::AsyncGcsClient *client, const TaskID &id, - const std::vector &data) { + gcs::AsyncGcsClient *client, const TaskID &id, + const std::vector &data) { std::vector appended_managers; for (const auto &entry : data) { appended_managers.push_back(entry.node_manager_id); @@ -282,8 +282,8 @@ void TestSet(const DriverID &driver_id, std::shared_ptr cli // Check that lookup returns the added object entries. auto lookup_callback = [object_id, managers]( - gcs::AsyncGcsClient *client, const ObjectID &id, - const std::vector &data) { + gcs::AsyncGcsClient *client, const ObjectID &id, + const std::vector &data) { ASSERT_EQ(id, object_id); ASSERT_EQ(data.size(), managers.size()); test->IncrementNumCallbacks(); @@ -296,8 +296,9 @@ void TestSet(const DriverID &driver_id, std::shared_ptr cli auto data = std::make_shared(); data->manager = manager; // Check that we added the correct object entries. - auto remove_entry_callback = [object_id, data]( - gcs::AsyncGcsClient *client, const ObjectID &id, const ObjectTableDataT &d) { + auto remove_entry_callback = [object_id, data](gcs::AsyncGcsClient *client, + const ObjectID &id, + const ObjectTableDataT &d) { ASSERT_EQ(id, object_id); ASSERT_EQ(data->manager, d.manager); test->IncrementNumCallbacks(); @@ -308,8 +309,8 @@ void TestSet(const DriverID &driver_id, std::shared_ptr cli // Check that the entries are removed. 
auto lookup_callback2 = [object_id, managers]( - gcs::AsyncGcsClient *client, const ObjectID &id, - const std::vector &data) { + gcs::AsyncGcsClient *client, const ObjectID &id, + const std::vector &data) { ASSERT_EQ(id, object_id); ASSERT_EQ(data.size(), 0); test->IncrementNumCallbacks(); @@ -350,8 +351,8 @@ void TestDeleteKeysFromLog( for (const auto &task_id : ids) { // Check that lookup returns the added object entries. auto lookup_callback = [task_id, data_vector]( - gcs::AsyncGcsClient *client, const TaskID &id, - const std::vector &data) { + gcs::AsyncGcsClient *client, const TaskID &id, + const std::vector &data) { ASSERT_EQ(id, task_id); ASSERT_EQ(data.size(), 1); test->IncrementNumCallbacks(); @@ -445,8 +446,8 @@ void TestDeleteKeysFromSet(const DriverID &driver_id, for (const auto &object_id : ids) { // Check that lookup returns the added object entries. auto lookup_callback = [object_id, data_vector]( - gcs::AsyncGcsClient *client, const ObjectID &id, - const std::vector &data) { + gcs::AsyncGcsClient *client, const ObjectID &id, + const std::vector &data) { ASSERT_EQ(id, object_id); ASSERT_EQ(data.size(), 1); test->IncrementNumCallbacks(); @@ -657,13 +658,13 @@ void TestSetSubscribeAll(const DriverID &driver_id, // Callback for a notification. auto notification_callback = [object_ids, managers]( - gcs::AsyncGcsClient *client, const ObjectID &id, - const GcsTableNotificationMode notification_mode, - const std::vector data) { + gcs::AsyncGcsClient *client, const ObjectID &id, + const GcsChangeMode change_mode, + const std::vector data) { if (test->NumCallbacks() < 3 * 3) { - ASSERT_EQ(notification_mode, GcsTableNotificationMode::APPEND_OR_ADD); + ASSERT_EQ(change_mode, GcsChangeMode::APPEND_OR_ADD); } else { - ASSERT_EQ(notification_mode, GcsTableNotificationMode::REMOVE); + ASSERT_EQ(change_mode, GcsChangeMode::REMOVE); } ASSERT_EQ(id, object_ids[test->NumCallbacks() / 3 % 3]); // Check that we get notifications in the same order as the writes. 
@@ -737,8 +738,9 @@ void TestTableSubscribeId(const DriverID &driver_id, // The callback for a notification from the table. This should only be // received for keys that we requested notifications for. - auto notification_callback = [task_id2, task_specs2]( - gcs::AsyncGcsClient *client, const TaskID &id, const protocol::TaskT &data) { + auto notification_callback = [task_id2, task_specs2](gcs::AsyncGcsClient *client, + const TaskID &id, + const protocol::TaskT &data) { // Check that we only get notifications for the requested key. ASSERT_EQ(id, task_id2); // Check that we get notifications in the same order as the writes. @@ -752,7 +754,7 @@ void TestTableSubscribeId(const DriverID &driver_id, // The failure callback should be called once since both keys start as empty. bool failure_notification_received = false; auto failure_callback = [task_id2, &failure_notification_received]( - gcs::AsyncGcsClient *client, const TaskID &id) { + gcs::AsyncGcsClient *client, const TaskID &id) { ASSERT_EQ(id, task_id2); // The failure notification should be the first notification received. ASSERT_EQ(test->NumCallbacks(), 0); @@ -820,8 +822,8 @@ void TestLogSubscribeId(const DriverID &driver_id, // The callback for a notification from the table. This should only be // received for keys that we requested notifications for. auto notification_callback = [driver_id2, driver_ids2]( - gcs::AsyncGcsClient *client, const UniqueID &id, - const std::vector &data) { + gcs::AsyncGcsClient *client, const UniqueID &id, + const std::vector &data) { // Check that we only get notifications for the requested key. ASSERT_EQ(id, driver_id2); // Check that we get notifications in the same order as the writes. @@ -894,10 +896,10 @@ void TestSetSubscribeId(const DriverID &driver_id, // The callback for a notification from the table. This should only be // received for keys that we requested notifications for. 
auto notification_callback = [object_id2, managers2]( - gcs::AsyncGcsClient *client, const ObjectID &id, - const GcsTableNotificationMode notification_mode, - const std::vector &data) { - ASSERT_EQ(notification_mode, GcsTableNotificationMode::APPEND_OR_ADD); + gcs::AsyncGcsClient *client, const ObjectID &id, + const GcsChangeMode change_mode, + const std::vector &data) { + ASSERT_EQ(change_mode, GcsChangeMode::APPEND_OR_ADD); // Check that we only get notifications for the requested key. ASSERT_EQ(id, object_id2); // Check that we get notifications in the same order as the writes. @@ -968,8 +970,9 @@ void TestTableSubscribeCancel(const DriverID &driver_id, // The callback for a notification from the table. This should only be // received for keys that we requested notifications for. - auto notification_callback = [task_id, task_specs]( - gcs::AsyncGcsClient *client, const TaskID &id, const protocol::TaskT &data) { + auto notification_callback = [task_id, task_specs](gcs::AsyncGcsClient *client, + const TaskID &id, + const protocol::TaskT &data) { ASSERT_EQ(id, task_id); // Check that we only get notifications for the first and last writes, // since notifications are canceled in between. @@ -1038,8 +1041,8 @@ void TestLogSubscribeCancel(const DriverID &driver_id, // The callback for a notification from the object table. This should only be // received for the object that we requested notifications for. auto notification_callback = [random_driver_id, driver_ids]( - gcs::AsyncGcsClient *client, const UniqueID &id, - const std::vector &data) { + gcs::AsyncGcsClient *client, const UniqueID &id, + const std::vector &data) { ASSERT_EQ(id, random_driver_id); // Check that we get a duplicate notification for the first write. We get a // duplicate notification because the log is append-only and notifications @@ -1111,10 +1114,10 @@ void TestSetSubscribeCancel(const DriverID &driver_id, // The callback for a notification from the object table. 
This should only be // received for the object that we requested notifications for. auto notification_callback = [object_id, managers]( - gcs::AsyncGcsClient *client, const ObjectID &id, - const GcsTableNotificationMode notification_mode, - const std::vector &data) { - ASSERT_EQ(notification_mode, GcsTableNotificationMode::APPEND_OR_ADD); + gcs::AsyncGcsClient *client, const ObjectID &id, + const GcsChangeMode change_mode, + const std::vector &data) { + ASSERT_EQ(change_mode, GcsChangeMode::APPEND_OR_ADD); ASSERT_EQ(id, object_id); // Check that we get a duplicate notification for the first write. We get a // duplicate notification because notifications @@ -1294,11 +1297,12 @@ void TestClientTableMarkDisconnected(const DriverID &driver_id, RAY_CHECK_OK(client->client_table().MarkDisconnected(dead_client_id)); // Make sure we only get a notification for the removal of the client we // marked as dead. - client->client_table().RegisterClientRemovedCallback([dead_client_id]( - gcs::AsyncGcsClient *client, const UniqueID &id, const ClientTableDataT &data) { - ASSERT_EQ(ClientID::FromBinary(data.client_id), dead_client_id); - test->Stop(); - }); + client->client_table().RegisterClientRemovedCallback( + [dead_client_id](gcs::AsyncGcsClient *client, const UniqueID &id, + const ClientTableDataT &data) { + ASSERT_EQ(ClientID::FromBinary(data.client_id), dead_client_id); + test->Stop(); + }); test->Start(); } @@ -1307,6 +1311,162 @@ TEST_F(TestGcsWithAsio, TestClientTableMarkDisconnected) { TestClientTableMarkDisconnected(driver_id_, client_); } +void TestHashTable(const DriverID &driver_id, + std::shared_ptr client) { + const int expected_count = 14; + ClientID client_id = ClientID::FromRandom(); + // Prepare the first resource map: data_map1. 
+ auto cpu_data = std::make_shared(); + cpu_data->resource_name = "CPU"; + cpu_data->resource_capacity = 100; + auto gpu_data = std::make_shared(); + gpu_data->resource_name = "GPU"; + gpu_data->resource_capacity = 2; + DynamicResourceTable::DataMap data_map1; + data_map1.emplace("CPU", cpu_data); + data_map1.emplace("GPU", gpu_data); + // Prepare the second resource map: data_map2 which decreases CPU, + // increases GPU and add a new CUSTOM compared to data_map1. + auto data_cpu = std::make_shared(); + data_cpu->resource_name = "CPU"; + data_cpu->resource_capacity = 50; + auto data_gpu = std::make_shared(); + data_gpu->resource_name = "GPU"; + data_gpu->resource_capacity = 10; + auto data_custom = std::make_shared(); + data_custom->resource_name = "CUSTOM"; + data_custom->resource_capacity = 2; + DynamicResourceTable::DataMap data_map2; + data_map2.emplace("CPU", data_cpu); + data_map2.emplace("GPU", data_gpu); + data_map2.emplace("CUSTOM", data_custom); + data_map2["CPU"]->resource_capacity = 50; + // This is a common comparison function for the test. 
+ auto compare_test = [](const DynamicResourceTable::DataMap &data1, + const DynamicResourceTable::DataMap &data2) { + ASSERT_EQ(data1.size(), data2.size()); + for (const auto &data : data1) { + auto iter = data2.find(data.first); + ASSERT_TRUE(iter != data2.end()); + ASSERT_EQ(iter->second->resource_name, data.second->resource_name); + ASSERT_EQ(iter->second->resource_capacity, data.second->resource_capacity); + } + }; + auto subscribe_callback = [](AsyncGcsClient *client) { + ASSERT_TRUE(true); + test->IncrementNumCallbacks(); + }; + auto notification_callback = [data_map1, data_map2, compare_test]( + AsyncGcsClient *client, const ClientID &id, + const GcsChangeMode change_mode, + const DynamicResourceTable::DataMap &data) { + if (change_mode == GcsChangeMode::REMOVE) { + ASSERT_EQ(data.size(), 2); + ASSERT_TRUE(data.find("GPU") != data.end()); + ASSERT_TRUE(data.find("CUSTOM") != data.end() || data.find("CPU") != data.end()); + // The key "None-Existent" will not appear in the notification. + } else { + if (data.size() == 2) { + compare_test(data_map1, data); + } else if (data.size() == 3) { + compare_test(data_map2, data); + } else { + ASSERT_TRUE(false); + } + } + test->IncrementNumCallbacks(); + // It is not sure which of the notification or lookup callback will come first. + if (test->NumCallbacks() == expected_count) { + test->Stop(); + } + }; + // Step 0: Subscribe the change of the hash table. + RAY_CHECK_OK(client->resource_table().Subscribe( + driver_id, ClientID::Nil(), notification_callback, subscribe_callback)); + RAY_CHECK_OK(client->resource_table().RequestNotifications( + driver_id, client_id, client->client_table().GetLocalClientId())); + + // Step 1: Add elements to the hash table. 
+ auto update_callback1 = [data_map1, compare_test]( + AsyncGcsClient *client, const ClientID &id, + const DynamicResourceTable::DataMap &callback_data) { + compare_test(data_map1, callback_data); + test->IncrementNumCallbacks(); + }; + RAY_CHECK_OK( + client->resource_table().Update(driver_id, client_id, data_map1, update_callback1)); + auto lookup_callback1 = [data_map1, compare_test]( + AsyncGcsClient *client, const ClientID &id, + const DynamicResourceTable::DataMap &callback_data) { + compare_test(data_map1, callback_data); + test->IncrementNumCallbacks(); + }; + RAY_CHECK_OK(client->resource_table().Lookup(driver_id, client_id, lookup_callback1)); + + // Step 2: Decrease one element, increase one and add a new one. + RAY_CHECK_OK(client->resource_table().Update(driver_id, client_id, data_map2, nullptr)); + auto lookup_callback2 = [data_map2, compare_test]( + AsyncGcsClient *client, const ClientID &id, + const DynamicResourceTable::DataMap &callback_data) { + compare_test(data_map2, callback_data); + test->IncrementNumCallbacks(); + }; + RAY_CHECK_OK(client->resource_table().Lookup(driver_id, client_id, lookup_callback2)); + std::vector delete_keys({"GPU", "CUSTOM", "None-Existent"}); + auto remove_callback = [delete_keys](AsyncGcsClient *client, const ClientID &id, + const std::vector &callback_data) { + for (int i = 0; i < callback_data.size(); ++i) { + // All deleting keys exist in this argument even if the key doesn't exist. 
+ ASSERT_EQ(callback_data[i], delete_keys[i]); + } + test->IncrementNumCallbacks(); + }; + RAY_CHECK_OK(client->resource_table().RemoveEntries(driver_id, client_id, delete_keys, + remove_callback)); + DynamicResourceTable::DataMap data_map3(data_map2); + data_map3.erase("GPU"); + data_map3.erase("CUSTOM"); + auto lookup_callback3 = [data_map3, compare_test]( + AsyncGcsClient *client, const ClientID &id, + const DynamicResourceTable::DataMap &callback_data) { + compare_test(data_map3, callback_data); + test->IncrementNumCallbacks(); + }; + RAY_CHECK_OK(client->resource_table().Lookup(driver_id, client_id, lookup_callback3)); + + // Step 3: Reset the the resources to data_map1. + RAY_CHECK_OK( + client->resource_table().Update(driver_id, client_id, data_map1, update_callback1)); + auto lookup_callback4 = [data_map1, compare_test]( + AsyncGcsClient *client, const ClientID &id, + const DynamicResourceTable::DataMap &callback_data) { + compare_test(data_map1, callback_data); + test->IncrementNumCallbacks(); + }; + RAY_CHECK_OK(client->resource_table().Lookup(driver_id, client_id, lookup_callback4)); + + // Step 4: Removing all elements will remove the home Hash table from GCS. + RAY_CHECK_OK(client->resource_table().RemoveEntries( + driver_id, client_id, {"GPU", "CPU", "CUSTOM", "None-Existent"}, nullptr)); + auto lookup_callback5 = [](AsyncGcsClient *client, const ClientID &id, + const DynamicResourceTable::DataMap &callback_data) { + ASSERT_EQ(callback_data.size(), 0); + test->IncrementNumCallbacks(); + // It is not sure which of notification or lookup callback will come first. 
+ if (test->NumCallbacks() == expected_count) { + test->Stop(); + } + }; + RAY_CHECK_OK(client->resource_table().Lookup(driver_id, client_id, lookup_callback5)); + test->Start(); + ASSERT_EQ(test->NumCallbacks(), expected_count); +} + +TEST_F(TestGcsWithAsio, TestHashTable) { + test = this; + TestHashTable(driver_id_, client_); +} + #undef TEST_MACRO } // namespace gcs diff --git a/src/ray/gcs/format/gcs.fbs b/src/ray/gcs/format/gcs.fbs index b81f388d88c5..614c80b27672 100644 --- a/src/ray/gcs/format/gcs.fbs +++ b/src/ray/gcs/format/gcs.fbs @@ -22,6 +22,7 @@ enum TablePrefix:int { TASK_LEASE, ACTOR_CHECKPOINT, ACTOR_CHECKPOINT_ID, + NODE_RESOURCE, } // The channel that Add operations to the Table should be published on, if any. @@ -37,6 +38,7 @@ enum TablePubsub:int { ERROR_INFO, TASK_LEASE, DRIVER, + NODE_RESOURCE, } // Enum for the entry type in the ClientTable @@ -113,13 +115,13 @@ table ResourcePair { value: double; } -enum GcsTableNotificationMode:int { +enum GcsChangeMode:int { APPEND_OR_ADD = 0, REMOVE, } -table GcsTableEntry { - notification_mode: GcsTableNotificationMode; +table GcsEntry { + change_mode: GcsChangeMode; id: string; entries: [string]; } diff --git a/src/ray/gcs/redis_context.cc b/src/ray/gcs/redis_context.cc index fe5ba3d1d134..ae6cb6088cec 100644 --- a/src/ray/gcs/redis_context.cc +++ b/src/ray/gcs/redis_context.cc @@ -48,28 +48,28 @@ namespace gcs { CallbackReply::CallbackReply(redisReply *redis_reply) { RAY_CHECK(nullptr != redis_reply); - RAY_CHECK(redis_reply->type != REDIS_REPLY_ERROR) << "Got an error in redis reply: " - << redis_reply->str; + RAY_CHECK(redis_reply->type != REDIS_REPLY_ERROR) + << "Got an error in redis reply: " << redis_reply->str; this->redis_reply_ = redis_reply; } bool CallbackReply::IsNil() const { return REDIS_REPLY_NIL == redis_reply_->type; } int64_t CallbackReply::ReadAsInteger() const { - RAY_CHECK(REDIS_REPLY_INTEGER == redis_reply_->type) << "Unexpected type: " - << redis_reply_->type; + 
RAY_CHECK(REDIS_REPLY_INTEGER == redis_reply_->type) + << "Unexpected type: " << redis_reply_->type; return static_cast(redis_reply_->integer); } std::string CallbackReply::ReadAsString() const { - RAY_CHECK(REDIS_REPLY_STRING == redis_reply_->type) << "Unexpected type: " - << redis_reply_->type; + RAY_CHECK(REDIS_REPLY_STRING == redis_reply_->type) + << "Unexpected type: " << redis_reply_->type; return std::string(redis_reply_->str, redis_reply_->len); } Status CallbackReply::ReadAsStatus() const { - RAY_CHECK(REDIS_REPLY_STATUS == redis_reply_->type) << "Unexpected type: " - << redis_reply_->type; + RAY_CHECK(REDIS_REPLY_STATUS == redis_reply_->type) + << "Unexpected type: " << redis_reply_->type; const std::string status_str(redis_reply_->str, redis_reply_->len); if ("OK" == status_str) { return Status::OK(); @@ -79,8 +79,8 @@ Status CallbackReply::ReadAsStatus() const { } std::string CallbackReply::ReadAsPubsubData() const { - RAY_CHECK(REDIS_REPLY_ARRAY == redis_reply_->type) << "Unexpected type: " - << redis_reply_->type; + RAY_CHECK(REDIS_REPLY_ARRAY == redis_reply_->type) + << "Unexpected type: " << redis_reply_->type; std::string data = ""; // Parse the published message. diff --git a/src/ray/gcs/redis_module/ray_redis_module.cc b/src/ray/gcs/redis_module/ray_redis_module.cc index 23e611e400df..e291b7ffdb32 100644 --- a/src/ray/gcs/redis_module/ray_redis_module.cc +++ b/src/ray/gcs/redis_module/ray_redis_module.cc @@ -179,32 +179,20 @@ flatbuffers::Offset RedisStringToFlatbuf( return fbb.CreateString(redis_string_str, redis_string_size); } -/// Publish a notification for an entry update at a key. This publishes a -/// notification to all subscribers of the table, as well as every client that -/// has requested notifications for this key. +/// Helper method to publish formatted data to target channel. /// /// \param pubsub_channel_str The pubsub channel name that notifications for /// this key should be published to. 
When publishing to a specific client, the /// channel name should be :. /// \param id The ID of the key that the notification is about. -/// \param mode the update mode, such as append or remove. -/// \param data The appended/removed data. +/// \param data_buffer The data to publish, which is a GcsEntry buffer. /// \return OK if there is no error during a publish. -int PublishTableUpdate(RedisModuleCtx *ctx, RedisModuleString *pubsub_channel_str, - RedisModuleString *id, GcsTableNotificationMode notification_mode, - RedisModuleString *data) { - // Serialize the notification to send. - flatbuffers::FlatBufferBuilder fbb; - auto data_flatbuf = RedisStringToFlatbuf(fbb, data); - auto message = - CreateGcsTableEntry(fbb, notification_mode, RedisStringToFlatbuf(fbb, id), - fbb.CreateVector(&data_flatbuf, 1)); - fbb.Finish(message); - +int PublishDataHelper(RedisModuleCtx *ctx, RedisModuleString *pubsub_channel_str, + RedisModuleString *id, RedisModuleString *data_buffer) { // Write the data back to any subscribers that are listening to all table // notifications. - RedisModuleCallReply *reply = RedisModule_Call(ctx, "PUBLISH", "sb", pubsub_channel_str, - fbb.GetBufferPointer(), fbb.GetSize()); + RedisModuleCallReply *reply = + RedisModule_Call(ctx, "PUBLISH", "ss", pubsub_channel_str, data_buffer); if (reply == NULL) { return RedisModule_ReplyWithError(ctx, "error during PUBLISH"); } @@ -221,8 +209,8 @@ int PublishTableUpdate(RedisModuleCtx *ctx, RedisModuleString *pubsub_channel_st // will be garbage collected by redis. 
auto channel = RedisModule_CreateString(ctx, client_channel.data(), client_channel.size()); - RedisModuleCallReply *reply = RedisModule_Call( - ctx, "PUBLISH", "sb", channel, fbb.GetBufferPointer(), fbb.GetSize()); + RedisModuleCallReply *reply = + RedisModule_Call(ctx, "PUBLISH", "ss", channel, data_buffer); if (reply == NULL) { return RedisModule_ReplyWithError(ctx, "error during PUBLISH"); } @@ -231,6 +219,31 @@ int PublishTableUpdate(RedisModuleCtx *ctx, RedisModuleString *pubsub_channel_st return RedisModule_ReplyWithSimpleString(ctx, "OK"); } +/// Publish a notification for an entry update at a key. This publishes a +/// notification to all subscribers of the table, as well as every client that +/// has requested notifications for this key. +/// +/// \param pubsub_channel_str The pubsub channel name that notifications for +/// this key should be published to. When publishing to a specific client, the +/// channel name should be :. +/// \param id The ID of the key that the notification is about. +/// \param mode the update mode, such as append or remove. +/// \param data The appended/removed data. +/// \return OK if there is no error during a publish. +int PublishTableUpdate(RedisModuleCtx *ctx, RedisModuleString *pubsub_channel_str, + RedisModuleString *id, GcsChangeMode change_mode, + RedisModuleString *data) { + // Serialize the notification to send. + flatbuffers::FlatBufferBuilder fbb; + auto data_flatbuf = RedisStringToFlatbuf(fbb, data); + auto message = CreateGcsEntry(fbb, change_mode, RedisStringToFlatbuf(fbb, id), + fbb.CreateVector(&data_flatbuf, 1)); + fbb.Finish(message); + auto data_buffer = RedisModule_CreateString( + ctx, reinterpret_cast(fbb.GetBufferPointer()), fbb.GetSize()); + return PublishDataHelper(ctx, pubsub_channel_str, id, data_buffer); +} + // RAY.TABLE_ADD: // TableAdd_RedisCommand: the actual command handler. // (helper) TableAdd_DoWrite: performs the write to redis state. 
@@ -266,8 +279,8 @@ int TableAdd_DoPublish(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) if (pubsub_channel != TablePubsub::NO_PUBLISH) { // All other pubsub channels write the data back directly onto the channel. - return PublishTableUpdate(ctx, pubsub_channel_str, id, - GcsTableNotificationMode::APPEND_OR_ADD, data); + return PublishTableUpdate(ctx, pubsub_channel_str, id, GcsChangeMode::APPEND_OR_ADD, + data); } else { return RedisModule_ReplyWithSimpleString(ctx, "OK"); } @@ -366,8 +379,8 @@ int TableAppend_DoPublish(RedisModuleCtx *ctx, RedisModuleString **argv, int /*a if (pubsub_channel != TablePubsub::NO_PUBLISH) { // All other pubsub channels write the data back directly onto the // channel. - return PublishTableUpdate(ctx, pubsub_channel_str, id, - GcsTableNotificationMode::APPEND_OR_ADD, data); + return PublishTableUpdate(ctx, pubsub_channel_str, id, GcsChangeMode::APPEND_OR_ADD, + data); } else { return RedisModule_ReplyWithSimpleString(ctx, "OK"); } @@ -419,10 +432,9 @@ int Set_DoPublish(RedisModuleCtx *ctx, RedisModuleString **argv, bool is_add) { if (pubsub_channel != TablePubsub::NO_PUBLISH) { // All other pubsub channels write the data back directly onto the // channel. - return PublishTableUpdate(ctx, pubsub_channel_str, id, - is_add ? GcsTableNotificationMode::APPEND_OR_ADD - : GcsTableNotificationMode::REMOVE, - data); + return PublishTableUpdate( + ctx, pubsub_channel_str, id, + is_add ? 
GcsChangeMode::APPEND_OR_ADD : GcsChangeMode::REMOVE, data); } else { return RedisModule_ReplyWithSimpleString(ctx, "OK"); } @@ -518,7 +530,125 @@ int SetRemove_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int ar return RedisModule_ReplyWithSimpleString(ctx, "OK"); } -/// A helper function to create and finish a GcsTableEntry, based on the +int Hash_DoPublish(RedisModuleCtx *ctx, RedisModuleString **argv) { + RedisModuleString *pubsub_channel_str = argv[2]; + RedisModuleString *id = argv[3]; + RedisModuleString *data = argv[4]; + // Publish a message on the requested pubsub channel if necessary. + TablePubsub pubsub_channel; + REPLY_AND_RETURN_IF_NOT_OK(ParseTablePubsub(&pubsub_channel, pubsub_channel_str)); + if (pubsub_channel != TablePubsub::NO_PUBLISH) { + // All other pubsub channels write the data back directly onto the + // channel. + return PublishDataHelper(ctx, pubsub_channel_str, id, data); + } else { + return RedisModule_ReplyWithSimpleString(ctx, "OK"); + } +} + +/// Do the hash table write operation. This is called by HashUpdate_RedisCommand. +/// +/// \param change_mode Output the mode of the operation: APPEND_OR_ADD or REMOVE. +/// \param changed_data Output the data that was actually applied (for REMOVE, only the keys actually deleted).
+int HashUpdate_DoWrite(RedisModuleCtx *ctx, RedisModuleString **argv, int argc, + GcsChangeMode *change_mode, RedisModuleString **changed_data) { + if (argc != 5) { + return RedisModule_WrongArity(ctx); + } + RedisModuleString *prefix_str = argv[1]; + RedisModuleString *id = argv[3]; + RedisModuleString *update_data = argv[4]; + + RedisModuleKey *key; + REPLY_AND_RETURN_IF_NOT_OK(OpenPrefixedKey( + &key, ctx, prefix_str, id, REDISMODULE_READ | REDISMODULE_WRITE, nullptr)); + int type = RedisModule_KeyType(key); + REPLY_AND_RETURN_IF_FALSE( + type == REDISMODULE_KEYTYPE_HASH || type == REDISMODULE_KEYTYPE_EMPTY, + "HashUpdate_DoWrite: entries must be a hash or an empty hash"); + + size_t update_data_len = 0; + const char *update_data_buf = RedisModule_StringPtrLen(update_data, &update_data_len); + + auto data_vec = flatbuffers::GetRoot(update_data_buf); + *change_mode = data_vec->change_mode(); + if (*change_mode == GcsChangeMode::APPEND_OR_ADD) { + // This code path means they are updating command. + size_t total_size = data_vec->entries()->size(); + REPLY_AND_RETURN_IF_FALSE(total_size % 2 == 0, "Invalid Hash Update data vector."); + for (int i = 0; i < total_size; i += 2) { + // Reconstruct a key-value pair from a flattened list. + RedisModuleString *entry_key = RedisModule_CreateString( + ctx, data_vec->entries()->Get(i)->data(), data_vec->entries()->Get(i)->size()); + RedisModuleString *entry_value = + RedisModule_CreateString(ctx, data_vec->entries()->Get(i + 1)->data(), + data_vec->entries()->Get(i + 1)->size()); + // Returning 0 if key exists(still updated), 1 if the key is created. + RAY_IGNORE_EXPR( + RedisModule_HashSet(key, REDISMODULE_HASH_NONE, entry_key, entry_value, NULL)); + } + *changed_data = update_data; + } else { + // This code path means the command wants to remove the entries. 
+ size_t total_size = data_vec->entries()->size(); + flatbuffers::FlatBufferBuilder fbb; + std::vector> data; + for (int i = 0; i < total_size; i++) { + RedisModuleString *entry_key = RedisModule_CreateString( + ctx, data_vec->entries()->Get(i)->data(), data_vec->entries()->Get(i)->size()); + int deleted_num = RedisModule_HashSet(key, REDISMODULE_HASH_NONE, entry_key, + REDISMODULE_HASH_DELETE, NULL); + if (deleted_num != 0) { + // The corresponding key is removed. + data.push_back(fbb.CreateString(data_vec->entries()->Get(i)->data(), + data_vec->entries()->Get(i)->size())); + } + } + auto message = + CreateGcsEntry(fbb, data_vec->change_mode(), + fbb.CreateString(data_vec->id()->data(), data_vec->id()->size()), + fbb.CreateVector(data)); + fbb.Finish(message); + *changed_data = RedisModule_CreateString( + ctx, reinterpret_cast(fbb.GetBufferPointer()), fbb.GetSize()); + auto size = RedisModule_ValueLength(key); + if (size == 0) { + REPLY_AND_RETURN_IF_FALSE(RedisModule_DeleteKey(key) == REDISMODULE_OK, + "ERR Failed to delete empty hash."); + } + } + return REDISMODULE_OK; +} + +/// Update entries for a hash table. +/// +/// This is called from a client with the command: +// +/// RAY.HASH_UPDATE +/// +/// \param table_prefix The prefix string for keys in this table. +/// \param pubsub_channel The pubsub channel name that notifications for this +/// key should be published to. When publishing to a specific client, the +/// channel name should be :. +/// \param id The ID of the key to update. +/// \param data The GcsEntry flatbuffer data used to update this hash table. +/// 1). For deletion, this is a list of keys. +/// 2). For updating, this is a list of pairs with each key followed by the value. +/// \return OK if the update succeeds, or an error message string if the update +/// fails.
+int HashUpdate_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { + GcsChangeMode mode; + RedisModuleString *changed_data = nullptr; + if (HashUpdate_DoWrite(ctx, argv, argc, &mode, &changed_data) != REDISMODULE_OK) { + return REDISMODULE_ERR; + } + // Replace the data with the changed data to do the publish. + std::vector new_argv(argv, argv + argc); + new_argv[4] = changed_data; + return Hash_DoPublish(ctx, new_argv.data()); +} + +/// A helper function to create and finish a GcsEntry, based on the /// current value or values at the given key. /// /// \param ctx The Redis module context. @@ -528,7 +658,7 @@ int SetRemove_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int ar /// \param prefix_str The string prefix associated with the open Redis key. /// When parsed, this is expected to be a TablePrefix. /// \param entry_id The UniqueID associated with the open Redis key. -/// \param fbb A flatbuffer builder used to build the GcsTableEntry. +/// \param fbb A flatbuffer builder used to build the GcsEntry. Status TableEntryToFlatbuf(RedisModuleCtx *ctx, RedisModuleKey *table_key, RedisModuleString *prefix_str, RedisModuleString *entry_id, flatbuffers::FlatBufferBuilder &fbb) { @@ -539,12 +669,13 @@ Status TableEntryToFlatbuf(RedisModuleCtx *ctx, RedisModuleKey *table_key, size_t data_len = 0; char *data_buf = RedisModule_StringDMA(table_key, &data_len, REDISMODULE_READ); auto data = fbb.CreateString(data_buf, data_len); - auto message = CreateGcsTableEntry(fbb, GcsTableNotificationMode::APPEND_OR_ADD, - RedisStringToFlatbuf(fbb, entry_id), - fbb.CreateVector(&data, 1)); + auto message = + CreateGcsEntry(fbb, GcsChangeMode::APPEND_OR_ADD, + RedisStringToFlatbuf(fbb, entry_id), fbb.CreateVector(&data, 1)); fbb.Finish(message); } break; case REDISMODULE_KEYTYPE_LIST: + case REDISMODULE_KEYTYPE_HASH: case REDISMODULE_KEYTYPE_SET: { RedisModule_CloseKey(table_key); // Close the key before executing the command. 
NOTE(swang): According to @@ -561,10 +692,13 @@ Status TableEntryToFlatbuf(RedisModuleCtx *ctx, RedisModuleKey *table_key, case REDISMODULE_KEYTYPE_SET: reply = RedisModule_Call(ctx, "SMEMBERS", "s", table_key_str); break; + case REDISMODULE_KEYTYPE_HASH: + reply = RedisModule_Call(ctx, "HGETALL", "s", table_key_str); + break; } // Build the flatbuffer from the set of log entries. if (reply == nullptr || RedisModule_CallReplyType(reply) != REDISMODULE_REPLY_ARRAY) { - return Status::RedisError("Empty list or wrong type"); + return Status::RedisError("Empty list/set/hash or wrong type"); } std::vector> data; for (size_t i = 0; i < RedisModule_CallReplyLength(reply); i++) { @@ -574,13 +708,13 @@ Status TableEntryToFlatbuf(RedisModuleCtx *ctx, RedisModuleKey *table_key, data.push_back(fbb.CreateString(element_str, len)); } auto message = - CreateGcsTableEntry(fbb, GcsTableNotificationMode::APPEND_OR_ADD, - RedisStringToFlatbuf(fbb, entry_id), fbb.CreateVector(data)); + CreateGcsEntry(fbb, GcsChangeMode::APPEND_OR_ADD, + RedisStringToFlatbuf(fbb, entry_id), fbb.CreateVector(data)); fbb.Finish(message); } break; case REDISMODULE_KEYTYPE_EMPTY: { - auto message = CreateGcsTableEntry( - fbb, GcsTableNotificationMode::APPEND_OR_ADD, RedisStringToFlatbuf(fbb, entry_id), + auto message = CreateGcsEntry( + fbb, GcsChangeMode::APPEND_OR_ADD, RedisStringToFlatbuf(fbb, entry_id), fbb.CreateVector(std::vector>())); fbb.Finish(message); } break; @@ -637,6 +771,7 @@ static Status DeleteKeyHelper(RedisModuleCtx *ctx, RedisModuleString *prefix_str return Status::RedisError("Key does not exist."); } auto key_type = RedisModule_KeyType(delete_key); + // Set/Hash will delete itself when the length is 0. if (key_type == REDISMODULE_KEYTYPE_STRING || key_type == REDISMODULE_KEYTYPE_LIST) { // Current Table or Log only has this two types of entries. 
RAY_RETURN_NOT_OK( @@ -869,10 +1004,11 @@ int DebugString_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int std::string debug_string = DebugString(); return RedisModule_ReplyWithStringBuffer(ctx, debug_string.data(), debug_string.size()); } -}; +}; // namespace internal_redis_commands // Wrap all Redis commands with Redis' auto memory management. AUTO_MEMORY(TableAdd_RedisCommand); +AUTO_MEMORY(HashUpdate_RedisCommand); AUTO_MEMORY(TableAppend_RedisCommand); AUTO_MEMORY(SetAdd_RedisCommand); AUTO_MEMORY(SetRemove_RedisCommand); @@ -929,6 +1065,11 @@ int RedisModule_OnLoad(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) return REDISMODULE_ERR; } + if (RedisModule_CreateCommand(ctx, "ray.hash_update", HashUpdate_RedisCommand, + "write pubsub", 0, 0, 0) == REDISMODULE_ERR) { + return REDISMODULE_ERR; + } + if (RedisModule_CreateCommand(ctx, "ray.table_request_notifications", TableRequestNotifications_RedisCommand, "write pubsub", 0, 0, 0) == REDISMODULE_ERR) { diff --git a/src/ray/gcs/redis_module/redismodule.h b/src/ray/gcs/redis_module/redismodule.h index 186e284c0e04..23721beef07d 100644 --- a/src/ray/gcs/redis_module/redismodule.h +++ b/src/ray/gcs/redis_module/redismodule.h @@ -1,9 +1,9 @@ #ifndef REDISMODULE_H #define REDISMODULE_H -#include #include #include +#include /* ---------------- Defines common between core and modules --------------- */ @@ -15,8 +15,8 @@ #define REDISMODULE_APIVER_1 1 /* API flags and constants */ -#define REDISMODULE_READ (1<<0) -#define REDISMODULE_WRITE (1<<1) +#define REDISMODULE_READ (1 << 0) +#define REDISMODULE_WRITE (1 << 1) #define REDISMODULE_LIST_HEAD 0 #define REDISMODULE_LIST_TAIL 1 @@ -45,30 +45,31 @@ #define REDISMODULE_NO_EXPIRE -1 /* Sorted set API flags. 
*/ -#define REDISMODULE_ZADD_XX (1<<0) -#define REDISMODULE_ZADD_NX (1<<1) -#define REDISMODULE_ZADD_ADDED (1<<2) -#define REDISMODULE_ZADD_UPDATED (1<<3) -#define REDISMODULE_ZADD_NOP (1<<4) +#define REDISMODULE_ZADD_XX (1 << 0) +#define REDISMODULE_ZADD_NX (1 << 1) +#define REDISMODULE_ZADD_ADDED (1 << 2) +#define REDISMODULE_ZADD_UPDATED (1 << 3) +#define REDISMODULE_ZADD_NOP (1 << 4) /* Hash API flags. */ -#define REDISMODULE_HASH_NONE 0 -#define REDISMODULE_HASH_NX (1<<0) -#define REDISMODULE_HASH_XX (1<<1) -#define REDISMODULE_HASH_CFIELDS (1<<2) -#define REDISMODULE_HASH_EXISTS (1<<3) +#define REDISMODULE_HASH_NONE 0 +#define REDISMODULE_HASH_NX (1 << 0) +#define REDISMODULE_HASH_XX (1 << 1) +#define REDISMODULE_HASH_CFIELDS (1 << 2) +#define REDISMODULE_HASH_EXISTS (1 << 3) /* A special pointer that we can use between the core and the module to signal * field deletion, and that is impossible to be a valid pointer. */ -#define REDISMODULE_HASH_DELETE ((RedisModuleString*)(long)1) +#define REDISMODULE_HASH_DELETE ((RedisModuleString *)(long)1) /* Error messages. 
*/ -#define REDISMODULE_ERRORMSG_WRONGTYPE "WRONGTYPE Operation against a key holding the wrong kind of value" +#define REDISMODULE_ERRORMSG_WRONGTYPE \ + "WRONGTYPE Operation against a key holding the wrong kind of value" -#define REDISMODULE_POSITIVE_INFINITE (1.0/0.0) -#define REDISMODULE_NEGATIVE_INFINITE (-1.0/0.0) +#define REDISMODULE_POSITIVE_INFINITE (1.0 / 0.0) +#define REDISMODULE_NEGATIVE_INFINITE (-1.0 / 0.0) -#define REDISMODULE_NOT_USED(V) ((void) V) +#define REDISMODULE_NOT_USED(V) ((void)V) /* ------------------------- End of common defines ------------------------ */ @@ -86,95 +87,142 @@ typedef struct RedisModuleType RedisModuleType; typedef struct RedisModuleDigest RedisModuleDigest; typedef struct RedisModuleBlockedClient RedisModuleBlockedClient; -typedef int (*RedisModuleCmdFunc) (RedisModuleCtx *ctx, RedisModuleString **argv, int argc); +typedef int (*RedisModuleCmdFunc)(RedisModuleCtx *ctx, RedisModuleString **argv, + int argc); typedef void *(*RedisModuleTypeLoadFunc)(RedisModuleIO *rdb, int encver); typedef void (*RedisModuleTypeSaveFunc)(RedisModuleIO *rdb, void *value); -typedef void (*RedisModuleTypeRewriteFunc)(RedisModuleIO *aof, RedisModuleString *key, void *value); +typedef void (*RedisModuleTypeRewriteFunc)(RedisModuleIO *aof, RedisModuleString *key, + void *value); typedef size_t (*RedisModuleTypeMemUsageFunc)(void *value); typedef void (*RedisModuleTypeDigestFunc)(RedisModuleDigest *digest, void *value); typedef void (*RedisModuleTypeFreeFunc)(void *value); #define REDISMODULE_TYPE_METHOD_VERSION 1 typedef struct RedisModuleTypeMethods { - uint64_t version; - RedisModuleTypeLoadFunc rdb_load; - RedisModuleTypeSaveFunc rdb_save; - RedisModuleTypeRewriteFunc aof_rewrite; - RedisModuleTypeMemUsageFunc mem_usage; - RedisModuleTypeDigestFunc digest; - RedisModuleTypeFreeFunc free; + uint64_t version; + RedisModuleTypeLoadFunc rdb_load; + RedisModuleTypeSaveFunc rdb_save; + RedisModuleTypeRewriteFunc aof_rewrite; + 
RedisModuleTypeMemUsageFunc mem_usage; + RedisModuleTypeDigestFunc digest; + RedisModuleTypeFreeFunc free; } RedisModuleTypeMethods; #define REDISMODULE_GET_API(name) \ - RedisModule_GetApi("RedisModule_" #name, ((void **)&RedisModule_ ## name)) + RedisModule_GetApi("RedisModule_" #name, ((void **)&RedisModule_##name)) #define REDISMODULE_API_FUNC(x) (*x) - void *REDISMODULE_API_FUNC(RedisModule_Alloc)(size_t bytes); void *REDISMODULE_API_FUNC(RedisModule_Realloc)(void *ptr, size_t bytes); void REDISMODULE_API_FUNC(RedisModule_Free)(void *ptr); void *REDISMODULE_API_FUNC(RedisModule_Calloc)(size_t nmemb, size_t size); char *REDISMODULE_API_FUNC(RedisModule_Strdup)(const char *str); int REDISMODULE_API_FUNC(RedisModule_GetApi)(const char *, void *); -int REDISMODULE_API_FUNC(RedisModule_CreateCommand)(RedisModuleCtx *ctx, const char *name, RedisModuleCmdFunc cmdfunc, const char *strflags, int firstkey, int lastkey, int keystep); -int REDISMODULE_API_FUNC(RedisModule_SetModuleAttribs)(RedisModuleCtx *ctx, const char *name, int ver, int apiver); +int REDISMODULE_API_FUNC(RedisModule_CreateCommand)(RedisModuleCtx *ctx, const char *name, + RedisModuleCmdFunc cmdfunc, + const char *strflags, int firstkey, + int lastkey, int keystep); +int REDISMODULE_API_FUNC(RedisModule_SetModuleAttribs)(RedisModuleCtx *ctx, + const char *name, int ver, + int apiver); int REDISMODULE_API_FUNC(RedisModule_WrongArity)(RedisModuleCtx *ctx); -int REDISMODULE_API_FUNC(RedisModule_ReplyWithLongLong)(RedisModuleCtx *ctx, long long ll); +int REDISMODULE_API_FUNC(RedisModule_ReplyWithLongLong)(RedisModuleCtx *ctx, + long long ll); int REDISMODULE_API_FUNC(RedisModule_GetSelectedDb)(RedisModuleCtx *ctx); int REDISMODULE_API_FUNC(RedisModule_SelectDb)(RedisModuleCtx *ctx, int newid); -void *REDISMODULE_API_FUNC(RedisModule_OpenKey)(RedisModuleCtx *ctx, RedisModuleString *keyname, int mode); +void *REDISMODULE_API_FUNC(RedisModule_OpenKey)(RedisModuleCtx *ctx, + RedisModuleString *keyname, int 
mode); void REDISMODULE_API_FUNC(RedisModule_CloseKey)(RedisModuleKey *kp); int REDISMODULE_API_FUNC(RedisModule_KeyType)(RedisModuleKey *kp); size_t REDISMODULE_API_FUNC(RedisModule_ValueLength)(RedisModuleKey *kp); -int REDISMODULE_API_FUNC(RedisModule_ListPush)(RedisModuleKey *kp, int where, RedisModuleString *ele); -RedisModuleString *REDISMODULE_API_FUNC(RedisModule_ListPop)(RedisModuleKey *key, int where); -RedisModuleCallReply *REDISMODULE_API_FUNC(RedisModule_Call)(RedisModuleCtx *ctx, const char *cmdname, const char *fmt, ...); -const char *REDISMODULE_API_FUNC(RedisModule_CallReplyProto)(RedisModuleCallReply *reply, size_t *len); +int REDISMODULE_API_FUNC(RedisModule_ListPush)(RedisModuleKey *kp, int where, + RedisModuleString *ele); +RedisModuleString *REDISMODULE_API_FUNC(RedisModule_ListPop)(RedisModuleKey *key, + int where); +RedisModuleCallReply *REDISMODULE_API_FUNC(RedisModule_Call)(RedisModuleCtx *ctx, + const char *cmdname, + const char *fmt, ...); +const char *REDISMODULE_API_FUNC(RedisModule_CallReplyProto)(RedisModuleCallReply *reply, + size_t *len); void REDISMODULE_API_FUNC(RedisModule_FreeCallReply)(RedisModuleCallReply *reply); int REDISMODULE_API_FUNC(RedisModule_CallReplyType)(RedisModuleCallReply *reply); long long REDISMODULE_API_FUNC(RedisModule_CallReplyInteger)(RedisModuleCallReply *reply); size_t REDISMODULE_API_FUNC(RedisModule_CallReplyLength)(RedisModuleCallReply *reply); -RedisModuleCallReply *REDISMODULE_API_FUNC(RedisModule_CallReplyArrayElement)(RedisModuleCallReply *reply, size_t idx); -RedisModuleString *REDISMODULE_API_FUNC(RedisModule_CreateString)(RedisModuleCtx *ctx, const char *ptr, size_t len); -RedisModuleString *REDISMODULE_API_FUNC(RedisModule_CreateStringFromLongLong)(RedisModuleCtx *ctx, long long ll); -RedisModuleString *REDISMODULE_API_FUNC(RedisModule_CreateStringFromString)(RedisModuleCtx *ctx, const RedisModuleString *str); -RedisModuleString 
*REDISMODULE_API_FUNC(RedisModule_CreateStringPrintf)(RedisModuleCtx *ctx, const char *fmt, ...); -void REDISMODULE_API_FUNC(RedisModule_FreeString)(RedisModuleCtx *ctx, RedisModuleString *str); -const char *REDISMODULE_API_FUNC(RedisModule_StringPtrLen)(const RedisModuleString *str, size_t *len); -int REDISMODULE_API_FUNC(RedisModule_ReplyWithError)(RedisModuleCtx *ctx, const char *err); -int REDISMODULE_API_FUNC(RedisModule_ReplyWithSimpleString)(RedisModuleCtx *ctx, const char *msg); +RedisModuleCallReply *REDISMODULE_API_FUNC(RedisModule_CallReplyArrayElement)( + RedisModuleCallReply *reply, size_t idx); +RedisModuleString *REDISMODULE_API_FUNC(RedisModule_CreateString)(RedisModuleCtx *ctx, + const char *ptr, + size_t len); +RedisModuleString *REDISMODULE_API_FUNC(RedisModule_CreateStringFromLongLong)( + RedisModuleCtx *ctx, long long ll); +RedisModuleString *REDISMODULE_API_FUNC(RedisModule_CreateStringFromString)( + RedisModuleCtx *ctx, const RedisModuleString *str); +RedisModuleString *REDISMODULE_API_FUNC(RedisModule_CreateStringPrintf)( + RedisModuleCtx *ctx, const char *fmt, ...); +void REDISMODULE_API_FUNC(RedisModule_FreeString)(RedisModuleCtx *ctx, + RedisModuleString *str); +const char *REDISMODULE_API_FUNC(RedisModule_StringPtrLen)(const RedisModuleString *str, + size_t *len); +int REDISMODULE_API_FUNC(RedisModule_ReplyWithError)(RedisModuleCtx *ctx, + const char *err); +int REDISMODULE_API_FUNC(RedisModule_ReplyWithSimpleString)(RedisModuleCtx *ctx, + const char *msg); int REDISMODULE_API_FUNC(RedisModule_ReplyWithArray)(RedisModuleCtx *ctx, long len); void REDISMODULE_API_FUNC(RedisModule_ReplySetArrayLength)(RedisModuleCtx *ctx, long len); -int REDISMODULE_API_FUNC(RedisModule_ReplyWithStringBuffer)(RedisModuleCtx *ctx, const char *buf, size_t len); -int REDISMODULE_API_FUNC(RedisModule_ReplyWithString)(RedisModuleCtx *ctx, RedisModuleString *str); +int REDISMODULE_API_FUNC(RedisModule_ReplyWithStringBuffer)(RedisModuleCtx *ctx, + const char *buf, 
size_t len); +int REDISMODULE_API_FUNC(RedisModule_ReplyWithString)(RedisModuleCtx *ctx, + RedisModuleString *str); int REDISMODULE_API_FUNC(RedisModule_ReplyWithNull)(RedisModuleCtx *ctx); int REDISMODULE_API_FUNC(RedisModule_ReplyWithDouble)(RedisModuleCtx *ctx, double d); -int REDISMODULE_API_FUNC(RedisModule_ReplyWithCallReply)(RedisModuleCtx *ctx, RedisModuleCallReply *reply); -int REDISMODULE_API_FUNC(RedisModule_StringToLongLong)(const RedisModuleString *str, long long *ll); -int REDISMODULE_API_FUNC(RedisModule_StringToDouble)(const RedisModuleString *str, double *d); +int REDISMODULE_API_FUNC(RedisModule_ReplyWithCallReply)(RedisModuleCtx *ctx, + RedisModuleCallReply *reply); +int REDISMODULE_API_FUNC(RedisModule_StringToLongLong)(const RedisModuleString *str, + long long *ll); +int REDISMODULE_API_FUNC(RedisModule_StringToDouble)(const RedisModuleString *str, + double *d); void REDISMODULE_API_FUNC(RedisModule_AutoMemory)(RedisModuleCtx *ctx); -int REDISMODULE_API_FUNC(RedisModule_Replicate)(RedisModuleCtx *ctx, const char *cmdname, const char *fmt, ...); +int REDISMODULE_API_FUNC(RedisModule_Replicate)(RedisModuleCtx *ctx, const char *cmdname, + const char *fmt, ...); int REDISMODULE_API_FUNC(RedisModule_ReplicateVerbatim)(RedisModuleCtx *ctx); -const char *REDISMODULE_API_FUNC(RedisModule_CallReplyStringPtr)(RedisModuleCallReply *reply, size_t *len); -RedisModuleString *REDISMODULE_API_FUNC(RedisModule_CreateStringFromCallReply)(RedisModuleCallReply *reply); +const char *REDISMODULE_API_FUNC(RedisModule_CallReplyStringPtr)( + RedisModuleCallReply *reply, size_t *len); +RedisModuleString *REDISMODULE_API_FUNC(RedisModule_CreateStringFromCallReply)( + RedisModuleCallReply *reply); int REDISMODULE_API_FUNC(RedisModule_DeleteKey)(RedisModuleKey *key); -int REDISMODULE_API_FUNC(RedisModule_StringSet)(RedisModuleKey *key, RedisModuleString *str); -char *REDISMODULE_API_FUNC(RedisModule_StringDMA)(RedisModuleKey *key, size_t *len, int mode); +int 
REDISMODULE_API_FUNC(RedisModule_StringSet)(RedisModuleKey *key, + RedisModuleString *str); +char *REDISMODULE_API_FUNC(RedisModule_StringDMA)(RedisModuleKey *key, size_t *len, + int mode); int REDISMODULE_API_FUNC(RedisModule_StringTruncate)(RedisModuleKey *key, size_t newlen); mstime_t REDISMODULE_API_FUNC(RedisModule_GetExpire)(RedisModuleKey *key); int REDISMODULE_API_FUNC(RedisModule_SetExpire)(RedisModuleKey *key, mstime_t expire); -int REDISMODULE_API_FUNC(RedisModule_ZsetAdd)(RedisModuleKey *key, double score, RedisModuleString *ele, int *flagsptr); -int REDISMODULE_API_FUNC(RedisModule_ZsetIncrby)(RedisModuleKey *key, double score, RedisModuleString *ele, int *flagsptr, double *newscore); -int REDISMODULE_API_FUNC(RedisModule_ZsetScore)(RedisModuleKey *key, RedisModuleString *ele, double *score); -int REDISMODULE_API_FUNC(RedisModule_ZsetRem)(RedisModuleKey *key, RedisModuleString *ele, int *deleted); +int REDISMODULE_API_FUNC(RedisModule_ZsetAdd)(RedisModuleKey *key, double score, + RedisModuleString *ele, int *flagsptr); +int REDISMODULE_API_FUNC(RedisModule_ZsetIncrby)(RedisModuleKey *key, double score, + RedisModuleString *ele, int *flagsptr, + double *newscore); +int REDISMODULE_API_FUNC(RedisModule_ZsetScore)(RedisModuleKey *key, + RedisModuleString *ele, double *score); +int REDISMODULE_API_FUNC(RedisModule_ZsetRem)(RedisModuleKey *key, RedisModuleString *ele, + int *deleted); void REDISMODULE_API_FUNC(RedisModule_ZsetRangeStop)(RedisModuleKey *key); -int REDISMODULE_API_FUNC(RedisModule_ZsetFirstInScoreRange)(RedisModuleKey *key, double min, double max, int minex, int maxex); -int REDISMODULE_API_FUNC(RedisModule_ZsetLastInScoreRange)(RedisModuleKey *key, double min, double max, int minex, int maxex); -int REDISMODULE_API_FUNC(RedisModule_ZsetFirstInLexRange)(RedisModuleKey *key, RedisModuleString *min, RedisModuleString *max); -int REDISMODULE_API_FUNC(RedisModule_ZsetLastInLexRange)(RedisModuleKey *key, RedisModuleString *min, RedisModuleString 
*max); -RedisModuleString *REDISMODULE_API_FUNC(RedisModule_ZsetRangeCurrentElement)(RedisModuleKey *key, double *score); +int REDISMODULE_API_FUNC(RedisModule_ZsetFirstInScoreRange)(RedisModuleKey *key, + double min, double max, + int minex, int maxex); +int REDISMODULE_API_FUNC(RedisModule_ZsetLastInScoreRange)(RedisModuleKey *key, + double min, double max, + int minex, int maxex); +int REDISMODULE_API_FUNC(RedisModule_ZsetFirstInLexRange)(RedisModuleKey *key, + RedisModuleString *min, + RedisModuleString *max); +int REDISMODULE_API_FUNC(RedisModule_ZsetLastInLexRange)(RedisModuleKey *key, + RedisModuleString *min, + RedisModuleString *max); +RedisModuleString *REDISMODULE_API_FUNC(RedisModule_ZsetRangeCurrentElement)( + RedisModuleKey *key, double *score); int REDISMODULE_API_FUNC(RedisModule_ZsetRangeNext)(RedisModuleKey *key); int REDISMODULE_API_FUNC(RedisModule_ZsetRangePrev)(RedisModuleKey *key); int REDISMODULE_API_FUNC(RedisModule_ZsetRangeEndReached)(RedisModuleKey *key); @@ -184,31 +232,49 @@ int REDISMODULE_API_FUNC(RedisModule_IsKeysPositionRequest)(RedisModuleCtx *ctx) void REDISMODULE_API_FUNC(RedisModule_KeyAtPos)(RedisModuleCtx *ctx, int pos); unsigned long long REDISMODULE_API_FUNC(RedisModule_GetClientId)(RedisModuleCtx *ctx); void *REDISMODULE_API_FUNC(RedisModule_PoolAlloc)(RedisModuleCtx *ctx, size_t bytes); -RedisModuleType *REDISMODULE_API_FUNC(RedisModule_CreateDataType)(RedisModuleCtx *ctx, const char *name, int encver, RedisModuleTypeMethods *typemethods); -int REDISMODULE_API_FUNC(RedisModule_ModuleTypeSetValue)(RedisModuleKey *key, RedisModuleType *mt, void *value); +RedisModuleType *REDISMODULE_API_FUNC(RedisModule_CreateDataType)( + RedisModuleCtx *ctx, const char *name, int encver, + RedisModuleTypeMethods *typemethods); +int REDISMODULE_API_FUNC(RedisModule_ModuleTypeSetValue)(RedisModuleKey *key, + RedisModuleType *mt, + void *value); RedisModuleType *REDISMODULE_API_FUNC(RedisModule_ModuleTypeGetType)(RedisModuleKey *key); void 
*REDISMODULE_API_FUNC(RedisModule_ModuleTypeGetValue)(RedisModuleKey *key); void REDISMODULE_API_FUNC(RedisModule_SaveUnsigned)(RedisModuleIO *io, uint64_t value); uint64_t REDISMODULE_API_FUNC(RedisModule_LoadUnsigned)(RedisModuleIO *io); void REDISMODULE_API_FUNC(RedisModule_SaveSigned)(RedisModuleIO *io, int64_t value); int64_t REDISMODULE_API_FUNC(RedisModule_LoadSigned)(RedisModuleIO *io); -void REDISMODULE_API_FUNC(RedisModule_EmitAOF)(RedisModuleIO *io, const char *cmdname, const char *fmt, ...); -void REDISMODULE_API_FUNC(RedisModule_SaveString)(RedisModuleIO *io, RedisModuleString *s); -void REDISMODULE_API_FUNC(RedisModule_SaveStringBuffer)(RedisModuleIO *io, const char *str, size_t len); +void REDISMODULE_API_FUNC(RedisModule_EmitAOF)(RedisModuleIO *io, const char *cmdname, + const char *fmt, ...); +void REDISMODULE_API_FUNC(RedisModule_SaveString)(RedisModuleIO *io, + RedisModuleString *s); +void REDISMODULE_API_FUNC(RedisModule_SaveStringBuffer)(RedisModuleIO *io, + const char *str, size_t len); RedisModuleString *REDISMODULE_API_FUNC(RedisModule_LoadString)(RedisModuleIO *io); -char *REDISMODULE_API_FUNC(RedisModule_LoadStringBuffer)(RedisModuleIO *io, size_t *lenptr); +char *REDISMODULE_API_FUNC(RedisModule_LoadStringBuffer)(RedisModuleIO *io, + size_t *lenptr); void REDISMODULE_API_FUNC(RedisModule_SaveDouble)(RedisModuleIO *io, double value); double REDISMODULE_API_FUNC(RedisModule_LoadDouble)(RedisModuleIO *io); void REDISMODULE_API_FUNC(RedisModule_SaveFloat)(RedisModuleIO *io, float value); float REDISMODULE_API_FUNC(RedisModule_LoadFloat)(RedisModuleIO *io); -void REDISMODULE_API_FUNC(RedisModule_Log)(RedisModuleCtx *ctx, const char *level, const char *fmt, ...); -void REDISMODULE_API_FUNC(RedisModule_LogIOError)(RedisModuleIO *io, const char *levelstr, const char *fmt, ...); -int REDISMODULE_API_FUNC(RedisModule_StringAppendBuffer)(RedisModuleCtx *ctx, RedisModuleString *str, const char *buf, size_t len); -void 
REDISMODULE_API_FUNC(RedisModule_RetainString)(RedisModuleCtx *ctx, RedisModuleString *str); -int REDISMODULE_API_FUNC(RedisModule_StringCompare)(RedisModuleString *a, RedisModuleString *b); +void REDISMODULE_API_FUNC(RedisModule_Log)(RedisModuleCtx *ctx, const char *level, + const char *fmt, ...); +void REDISMODULE_API_FUNC(RedisModule_LogIOError)(RedisModuleIO *io, const char *levelstr, + const char *fmt, ...); +int REDISMODULE_API_FUNC(RedisModule_StringAppendBuffer)(RedisModuleCtx *ctx, + RedisModuleString *str, + const char *buf, size_t len); +void REDISMODULE_API_FUNC(RedisModule_RetainString)(RedisModuleCtx *ctx, + RedisModuleString *str); +int REDISMODULE_API_FUNC(RedisModule_StringCompare)(RedisModuleString *a, + RedisModuleString *b); RedisModuleCtx *REDISMODULE_API_FUNC(RedisModule_GetContextFromIO)(RedisModuleIO *io); -RedisModuleBlockedClient *REDISMODULE_API_FUNC(RedisModule_BlockClient)(RedisModuleCtx *ctx, RedisModuleCmdFunc reply_callback, RedisModuleCmdFunc timeout_callback, void (*free_privdata)(void*), long long timeout_ms); -int REDISMODULE_API_FUNC(RedisModule_UnblockClient)(RedisModuleBlockedClient *bc, void *privdata); +RedisModuleBlockedClient *REDISMODULE_API_FUNC(RedisModule_BlockClient)( + RedisModuleCtx *ctx, RedisModuleCmdFunc reply_callback, + RedisModuleCmdFunc timeout_callback, void (*free_privdata)(void *), + long long timeout_ms); +int REDISMODULE_API_FUNC(RedisModule_UnblockClient)(RedisModuleBlockedClient *bc, + void *privdata); int REDISMODULE_API_FUNC(RedisModule_IsBlockedReplyRequest)(RedisModuleCtx *ctx); int REDISMODULE_API_FUNC(RedisModule_IsBlockedTimeoutRequest)(RedisModuleCtx *ctx); void *REDISMODULE_API_FUNC(RedisModule_GetBlockedClientPrivateData)(RedisModuleCtx *ctx); @@ -216,115 +282,116 @@ int REDISMODULE_API_FUNC(RedisModule_AbortBlock)(RedisModuleBlockedClient *bc); long long REDISMODULE_API_FUNC(RedisModule_Milliseconds)(void); /* This is included inline inside each Redis module. 
*/ -static int RedisModule_Init(RedisModuleCtx *ctx, const char *name, int ver, int apiver) __attribute__((unused)); +static int RedisModule_Init(RedisModuleCtx *ctx, const char *name, int ver, int apiver) + __attribute__((unused)); static int RedisModule_Init(RedisModuleCtx *ctx, const char *name, int ver, int apiver) { - void *getapifuncptr = ((void**)ctx)[0]; - RedisModule_GetApi = (int (*)(const char *, void *)) (unsigned long)getapifuncptr; - REDISMODULE_GET_API(Alloc); - REDISMODULE_GET_API(Calloc); - REDISMODULE_GET_API(Free); - REDISMODULE_GET_API(Realloc); - REDISMODULE_GET_API(Strdup); - REDISMODULE_GET_API(CreateCommand); - REDISMODULE_GET_API(SetModuleAttribs); - REDISMODULE_GET_API(WrongArity); - REDISMODULE_GET_API(ReplyWithLongLong); - REDISMODULE_GET_API(ReplyWithError); - REDISMODULE_GET_API(ReplyWithSimpleString); - REDISMODULE_GET_API(ReplyWithArray); - REDISMODULE_GET_API(ReplySetArrayLength); - REDISMODULE_GET_API(ReplyWithStringBuffer); - REDISMODULE_GET_API(ReplyWithString); - REDISMODULE_GET_API(ReplyWithNull); - REDISMODULE_GET_API(ReplyWithCallReply); - REDISMODULE_GET_API(ReplyWithDouble); - REDISMODULE_GET_API(ReplySetArrayLength); - REDISMODULE_GET_API(GetSelectedDb); - REDISMODULE_GET_API(SelectDb); - REDISMODULE_GET_API(OpenKey); - REDISMODULE_GET_API(CloseKey); - REDISMODULE_GET_API(KeyType); - REDISMODULE_GET_API(ValueLength); - REDISMODULE_GET_API(ListPush); - REDISMODULE_GET_API(ListPop); - REDISMODULE_GET_API(StringToLongLong); - REDISMODULE_GET_API(StringToDouble); - REDISMODULE_GET_API(Call); - REDISMODULE_GET_API(CallReplyProto); - REDISMODULE_GET_API(FreeCallReply); - REDISMODULE_GET_API(CallReplyInteger); - REDISMODULE_GET_API(CallReplyType); - REDISMODULE_GET_API(CallReplyLength); - REDISMODULE_GET_API(CallReplyArrayElement); - REDISMODULE_GET_API(CallReplyStringPtr); - REDISMODULE_GET_API(CreateStringFromCallReply); - REDISMODULE_GET_API(CreateString); - REDISMODULE_GET_API(CreateStringFromLongLong); - 
REDISMODULE_GET_API(CreateStringFromString); - REDISMODULE_GET_API(CreateStringPrintf); - REDISMODULE_GET_API(FreeString); - REDISMODULE_GET_API(StringPtrLen); - REDISMODULE_GET_API(AutoMemory); - REDISMODULE_GET_API(Replicate); - REDISMODULE_GET_API(ReplicateVerbatim); - REDISMODULE_GET_API(DeleteKey); - REDISMODULE_GET_API(StringSet); - REDISMODULE_GET_API(StringDMA); - REDISMODULE_GET_API(StringTruncate); - REDISMODULE_GET_API(GetExpire); - REDISMODULE_GET_API(SetExpire); - REDISMODULE_GET_API(ZsetAdd); - REDISMODULE_GET_API(ZsetIncrby); - REDISMODULE_GET_API(ZsetScore); - REDISMODULE_GET_API(ZsetRem); - REDISMODULE_GET_API(ZsetRangeStop); - REDISMODULE_GET_API(ZsetFirstInScoreRange); - REDISMODULE_GET_API(ZsetLastInScoreRange); - REDISMODULE_GET_API(ZsetFirstInLexRange); - REDISMODULE_GET_API(ZsetLastInLexRange); - REDISMODULE_GET_API(ZsetRangeCurrentElement); - REDISMODULE_GET_API(ZsetRangeNext); - REDISMODULE_GET_API(ZsetRangePrev); - REDISMODULE_GET_API(ZsetRangeEndReached); - REDISMODULE_GET_API(HashSet); - REDISMODULE_GET_API(HashGet); - REDISMODULE_GET_API(IsKeysPositionRequest); - REDISMODULE_GET_API(KeyAtPos); - REDISMODULE_GET_API(GetClientId); - REDISMODULE_GET_API(PoolAlloc); - REDISMODULE_GET_API(CreateDataType); - REDISMODULE_GET_API(ModuleTypeSetValue); - REDISMODULE_GET_API(ModuleTypeGetType); - REDISMODULE_GET_API(ModuleTypeGetValue); - REDISMODULE_GET_API(SaveUnsigned); - REDISMODULE_GET_API(LoadUnsigned); - REDISMODULE_GET_API(SaveSigned); - REDISMODULE_GET_API(LoadSigned); - REDISMODULE_GET_API(SaveString); - REDISMODULE_GET_API(SaveStringBuffer); - REDISMODULE_GET_API(LoadString); - REDISMODULE_GET_API(LoadStringBuffer); - REDISMODULE_GET_API(SaveDouble); - REDISMODULE_GET_API(LoadDouble); - REDISMODULE_GET_API(SaveFloat); - REDISMODULE_GET_API(LoadFloat); - REDISMODULE_GET_API(EmitAOF); - REDISMODULE_GET_API(Log); - REDISMODULE_GET_API(LogIOError); - REDISMODULE_GET_API(StringAppendBuffer); - REDISMODULE_GET_API(RetainString); - 
REDISMODULE_GET_API(StringCompare); - REDISMODULE_GET_API(GetContextFromIO); - REDISMODULE_GET_API(BlockClient); - REDISMODULE_GET_API(UnblockClient); - REDISMODULE_GET_API(IsBlockedReplyRequest); - REDISMODULE_GET_API(IsBlockedTimeoutRequest); - REDISMODULE_GET_API(GetBlockedClientPrivateData); - REDISMODULE_GET_API(AbortBlock); - REDISMODULE_GET_API(Milliseconds); - - RedisModule_SetModuleAttribs(ctx,name,ver,apiver); - return REDISMODULE_OK; + void *getapifuncptr = ((void **)ctx)[0]; + RedisModule_GetApi = (int (*)(const char *, void *))(unsigned long)getapifuncptr; + REDISMODULE_GET_API(Alloc); + REDISMODULE_GET_API(Calloc); + REDISMODULE_GET_API(Free); + REDISMODULE_GET_API(Realloc); + REDISMODULE_GET_API(Strdup); + REDISMODULE_GET_API(CreateCommand); + REDISMODULE_GET_API(SetModuleAttribs); + REDISMODULE_GET_API(WrongArity); + REDISMODULE_GET_API(ReplyWithLongLong); + REDISMODULE_GET_API(ReplyWithError); + REDISMODULE_GET_API(ReplyWithSimpleString); + REDISMODULE_GET_API(ReplyWithArray); + REDISMODULE_GET_API(ReplySetArrayLength); + REDISMODULE_GET_API(ReplyWithStringBuffer); + REDISMODULE_GET_API(ReplyWithString); + REDISMODULE_GET_API(ReplyWithNull); + REDISMODULE_GET_API(ReplyWithCallReply); + REDISMODULE_GET_API(ReplyWithDouble); + REDISMODULE_GET_API(ReplySetArrayLength); + REDISMODULE_GET_API(GetSelectedDb); + REDISMODULE_GET_API(SelectDb); + REDISMODULE_GET_API(OpenKey); + REDISMODULE_GET_API(CloseKey); + REDISMODULE_GET_API(KeyType); + REDISMODULE_GET_API(ValueLength); + REDISMODULE_GET_API(ListPush); + REDISMODULE_GET_API(ListPop); + REDISMODULE_GET_API(StringToLongLong); + REDISMODULE_GET_API(StringToDouble); + REDISMODULE_GET_API(Call); + REDISMODULE_GET_API(CallReplyProto); + REDISMODULE_GET_API(FreeCallReply); + REDISMODULE_GET_API(CallReplyInteger); + REDISMODULE_GET_API(CallReplyType); + REDISMODULE_GET_API(CallReplyLength); + REDISMODULE_GET_API(CallReplyArrayElement); + REDISMODULE_GET_API(CallReplyStringPtr); + 
REDISMODULE_GET_API(CreateStringFromCallReply); + REDISMODULE_GET_API(CreateString); + REDISMODULE_GET_API(CreateStringFromLongLong); + REDISMODULE_GET_API(CreateStringFromString); + REDISMODULE_GET_API(CreateStringPrintf); + REDISMODULE_GET_API(FreeString); + REDISMODULE_GET_API(StringPtrLen); + REDISMODULE_GET_API(AutoMemory); + REDISMODULE_GET_API(Replicate); + REDISMODULE_GET_API(ReplicateVerbatim); + REDISMODULE_GET_API(DeleteKey); + REDISMODULE_GET_API(StringSet); + REDISMODULE_GET_API(StringDMA); + REDISMODULE_GET_API(StringTruncate); + REDISMODULE_GET_API(GetExpire); + REDISMODULE_GET_API(SetExpire); + REDISMODULE_GET_API(ZsetAdd); + REDISMODULE_GET_API(ZsetIncrby); + REDISMODULE_GET_API(ZsetScore); + REDISMODULE_GET_API(ZsetRem); + REDISMODULE_GET_API(ZsetRangeStop); + REDISMODULE_GET_API(ZsetFirstInScoreRange); + REDISMODULE_GET_API(ZsetLastInScoreRange); + REDISMODULE_GET_API(ZsetFirstInLexRange); + REDISMODULE_GET_API(ZsetLastInLexRange); + REDISMODULE_GET_API(ZsetRangeCurrentElement); + REDISMODULE_GET_API(ZsetRangeNext); + REDISMODULE_GET_API(ZsetRangePrev); + REDISMODULE_GET_API(ZsetRangeEndReached); + REDISMODULE_GET_API(HashSet); + REDISMODULE_GET_API(HashGet); + REDISMODULE_GET_API(IsKeysPositionRequest); + REDISMODULE_GET_API(KeyAtPos); + REDISMODULE_GET_API(GetClientId); + REDISMODULE_GET_API(PoolAlloc); + REDISMODULE_GET_API(CreateDataType); + REDISMODULE_GET_API(ModuleTypeSetValue); + REDISMODULE_GET_API(ModuleTypeGetType); + REDISMODULE_GET_API(ModuleTypeGetValue); + REDISMODULE_GET_API(SaveUnsigned); + REDISMODULE_GET_API(LoadUnsigned); + REDISMODULE_GET_API(SaveSigned); + REDISMODULE_GET_API(LoadSigned); + REDISMODULE_GET_API(SaveString); + REDISMODULE_GET_API(SaveStringBuffer); + REDISMODULE_GET_API(LoadString); + REDISMODULE_GET_API(LoadStringBuffer); + REDISMODULE_GET_API(SaveDouble); + REDISMODULE_GET_API(LoadDouble); + REDISMODULE_GET_API(SaveFloat); + REDISMODULE_GET_API(LoadFloat); + REDISMODULE_GET_API(EmitAOF); + 
REDISMODULE_GET_API(Log); + REDISMODULE_GET_API(LogIOError); + REDISMODULE_GET_API(StringAppendBuffer); + REDISMODULE_GET_API(RetainString); + REDISMODULE_GET_API(StringCompare); + REDISMODULE_GET_API(GetContextFromIO); + REDISMODULE_GET_API(BlockClient); + REDISMODULE_GET_API(UnblockClient); + REDISMODULE_GET_API(IsBlockedReplyRequest); + REDISMODULE_GET_API(IsBlockedTimeoutRequest); + REDISMODULE_GET_API(GetBlockedClientPrivateData); + REDISMODULE_GET_API(AbortBlock); + REDISMODULE_GET_API(Milliseconds); + + RedisModule_SetModuleAttribs(ctx, name, ver, apiver); + return REDISMODULE_OK; } #else diff --git a/src/ray/gcs/tables.cc b/src/ray/gcs/tables.cc index ffc44daa049a..33f1615580a6 100644 --- a/src/ray/gcs/tables.cc +++ b/src/ray/gcs/tables.cc @@ -92,7 +92,7 @@ Status Log::Lookup(const DriverID &driver_id, const ID &id, std::vector results; if (!reply.IsNil()) { const auto data = reply.ReadAsString(); - auto root = flatbuffers::GetRoot(data.data()); + auto root = flatbuffers::GetRoot(data.data()); RAY_CHECK(from_flatbuf(*root->id()) == id); for (size_t i = 0; i < root->entries()->size(); i++) { DataT result; @@ -114,9 +114,9 @@ Status Log::Subscribe(const DriverID &driver_id, const ClientID &clien const Callback &subscribe, const SubscriptionCallback &done) { auto subscribe_wrapper = [subscribe](AsyncGcsClient *client, const ID &id, - const GcsTableNotificationMode notification_mode, + const GcsChangeMode change_mode, const std::vector &data) { - RAY_CHECK(notification_mode != GcsTableNotificationMode::REMOVE); + RAY_CHECK(change_mode != GcsChangeMode::REMOVE); subscribe(client, id, data); }; return Subscribe(driver_id, client_id, subscribe_wrapper, done); @@ -141,7 +141,7 @@ Status Log::Subscribe(const DriverID &driver_id, const ClientID &clien // Data is provided. This is the callback for a message. if (subscribe != nullptr) { // Parse the notification. 
- auto root = flatbuffers::GetRoot(data.data()); + auto root = flatbuffers::GetRoot(data.data()); ID id; if (root->id()->size() > 0) { id = from_flatbuf(*root->id()); @@ -153,7 +153,7 @@ Status Log::Subscribe(const DriverID &driver_id, const ClientID &clien data_root->UnPackTo(&result); results.emplace_back(std::move(result)); } - subscribe(client_, id, root->notification_mode(), results); + subscribe(client_, id, root->change_mode(), results); } } }; @@ -339,6 +339,155 @@ std::string Set::DebugString() const { return result.str(); } +template +Status Hash::Update(const DriverID &driver_id, const ID &id, + const DataMap &data_map, const HashCallback &done) { + num_adds_++; + auto callback = [this, id, data_map, done](const CallbackReply &reply) { + if (done != nullptr) { + (done)(client_, id, data_map); + } + }; + flatbuffers::FlatBufferBuilder fbb; + std::vector> data_vec; + data_vec.reserve(data_map.size() * 2); + for (auto const &pair : data_map) { + // Add the key. + data_vec.push_back(fbb.CreateString(pair.first)); + flatbuffers::FlatBufferBuilder fbb_data; + fbb_data.ForceDefaults(true); + fbb_data.Finish(Data::Pack(fbb_data, pair.second.get())); + std::string data(reinterpret_cast(fbb_data.GetBufferPointer()), + fbb_data.GetSize()); + // Add the value. 
+ data_vec.push_back(fbb.CreateString(data)); + } + + fbb.Finish(CreateGcsEntry(fbb, GcsChangeMode::APPEND_OR_ADD, + fbb.CreateString(id.Binary()), fbb.CreateVector(data_vec))); + return GetRedisContext(id)->RunAsync("RAY.HASH_UPDATE", id, fbb.GetBufferPointer(), + fbb.GetSize(), prefix_, pubsub_channel_, + std::move(callback)); +} + +template +Status Hash::RemoveEntries(const DriverID &driver_id, const ID &id, + const std::vector &keys, + const HashRemoveCallback &remove_callback) { + num_removes_++; + auto callback = [this, id, keys, remove_callback](const CallbackReply &reply) { + if (remove_callback != nullptr) { + (remove_callback)(client_, id, keys); + } + }; + flatbuffers::FlatBufferBuilder fbb; + std::vector> data_vec; + data_vec.reserve(keys.size()); + // Add the keys. + for (auto const &key : keys) { + data_vec.push_back(fbb.CreateString(key)); + } + + fbb.Finish(CreateGcsEntry(fbb, GcsChangeMode::REMOVE, fbb.CreateString(id.Binary()), + fbb.CreateVector(data_vec))); + return GetRedisContext(id)->RunAsync("RAY.HASH_UPDATE", id, fbb.GetBufferPointer(), + fbb.GetSize(), prefix_, pubsub_channel_, + std::move(callback)); +} + +template +std::string Hash::DebugString() const { + std::stringstream result; + result << "num lookups: " << num_lookups_ << ", num adds: " << num_adds_ + << ", num removes: " << num_removes_; + return result.str(); +} + +template +Status Hash::Lookup(const DriverID &driver_id, const ID &id, + const HashCallback &lookup) { + num_lookups_++; + auto callback = [this, id, lookup](const CallbackReply &reply) { + if (lookup != nullptr) { + DataMap results; + if (!reply.IsNil()) { + const auto data = reply.ReadAsString(); + auto root = flatbuffers::GetRoot(data.data()); + RAY_CHECK(from_flatbuf(*root->id()) == id); + RAY_CHECK(root->entries()->size() % 2 == 0); + for (size_t i = 0; i < root->entries()->size(); i += 2) { + std::string key(root->entries()->Get(i)->data(), + root->entries()->Get(i)->size()); + auto result = std::make_shared(); + 
auto data_root = + flatbuffers::GetRoot(root->entries()->Get(i + 1)->data()); + data_root->UnPackTo(result.get()); + results.emplace(key, std::move(result)); + } + } + lookup(client_, id, results); + } + }; + std::vector nil; + return GetRedisContext(id)->RunAsync("RAY.TABLE_LOOKUP", id, nil.data(), nil.size(), + prefix_, pubsub_channel_, std::move(callback)); +} + +template +Status Hash::Subscribe(const DriverID &driver_id, const ClientID &client_id, + const HashNotificationCallback &subscribe, + const SubscriptionCallback &done) { + RAY_CHECK(subscribe_callback_index_ == -1) + << "Client called Subscribe twice on the same table"; + auto callback = [this, subscribe, done](const CallbackReply &reply) { + const auto data = reply.ReadAsPubsubData(); + if (data.empty()) { + // No notification data is provided. This is the callback for the + // initial subscription request. + if (done != nullptr) { + done(client_); + } + } else { + // Data is provided. This is the callback for a message. + if (subscribe != nullptr) { + // Parse the notification. 
+ auto root = flatbuffers::GetRoot(data.data()); + DataMap data_map; + ID id; + if (root->id()->size() > 0) { + id = from_flatbuf(*root->id()); + } + if (root->change_mode() == GcsChangeMode::REMOVE) { + for (size_t i = 0; i < root->entries()->size(); i++) { + std::string key(root->entries()->Get(i)->data(), + root->entries()->Get(i)->size()); + data_map.emplace(key, std::shared_ptr()); + } + } else { + RAY_CHECK(root->entries()->size() % 2 == 0); + for (size_t i = 0; i < root->entries()->size(); i += 2) { + std::string key(root->entries()->Get(i)->data(), + root->entries()->Get(i)->size()); + auto result = std::make_shared(); + auto data_root = + flatbuffers::GetRoot(root->entries()->Get(i + 1)->data()); + data_root->UnPackTo(result.get()); + data_map.emplace(key, std::move(result)); + } + } + subscribe(client_, id, root->change_mode(), data_map); + } + } + }; + + subscribe_callback_index_ = 1; + for (auto &context : shard_contexts_) { + RAY_RETURN_NOT_OK(context->SubscribeAsync(client_id, pubsub_channel_, callback, + &subscribe_callback_index_)); + } + return Status::OK(); +} + Status ErrorTable::PushErrorToDriver(const DriverID &driver_id, const std::string &type, const std::string &error_message, double timestamp) { auto data = std::make_shared(); @@ -525,8 +674,8 @@ void ClientTable::HandleNotification(AsyncGcsClient *client, void ClientTable::HandleConnected(AsyncGcsClient *client, const ClientTableDataT &data) { auto connected_client_id = ClientID::FromBinary(data.client_id); - RAY_CHECK(client_id_ == connected_client_id) << connected_client_id << " " - << client_id_; + RAY_CHECK(client_id_ == connected_client_id) + << connected_client_id << " " << client_id_; } const ClientID &ClientTable::GetLocalClientId() const { return client_id_; } @@ -555,8 +704,8 @@ Status ClientTable::Connect(const ClientTableDataT &local_client) { // Callback for a notification from the client table. 
auto notification_callback = [this]( - AsyncGcsClient *client, const UniqueID &log_key, - const std::vector ¬ifications) { + AsyncGcsClient *client, const UniqueID &log_key, + const std::vector ¬ifications) { RAY_CHECK(log_key == client_log_key_); std::unordered_map connected_nodes; std::unordered_map disconnected_nodes; @@ -648,8 +797,8 @@ Status ActorCheckpointIdTable::AddCheckpointId(const DriverID &driver_id, const ActorID &actor_id, const ActorCheckpointID &checkpoint_id) { auto lookup_callback = [this, checkpoint_id, driver_id, actor_id]( - ray::gcs::AsyncGcsClient *client, const UniqueID &id, - const ActorCheckpointIdDataT &data) { + ray::gcs::AsyncGcsClient *client, const UniqueID &id, + const ActorCheckpointIdDataT &data) { std::shared_ptr copy = std::make_shared(data); copy->timestamps.push_back(current_sys_time_ms()); @@ -668,7 +817,7 @@ Status ActorCheckpointIdTable::AddCheckpointId(const DriverID &driver_id, RAY_CHECK_OK(Add(driver_id, actor_id, copy, nullptr)); }; auto failure_callback = [this, checkpoint_id, driver_id, actor_id]( - ray::gcs::AsyncGcsClient *client, const UniqueID &id) { + ray::gcs::AsyncGcsClient *client, const UniqueID &id) { std::shared_ptr data = std::make_shared(); data->actor_id = id.Binary(); @@ -696,6 +845,9 @@ template class Log; template class Table; template class Table; +template class Log; +template class Hash; + } // namespace gcs } // namespace ray diff --git a/src/ray/gcs/tables.h b/src/ray/gcs/tables.h index af42509bda96..6a1d502a7f54 100644 --- a/src/ray/gcs/tables.h +++ b/src/ray/gcs/tables.h @@ -75,9 +75,9 @@ class Log : public LogInterface, virtual public PubsubInterface { using DataT = typename Data::NativeTableType; using Callback = std::function &data)>; - using NotificationCallback = std::function &data)>; + using NotificationCallback = std::function &data)>; /// The callback to call when a write to a key succeeds. 
using WriteCallback = typename LogInterface::WriteCallback; /// The callback to call when a SUBSCRIBE call completes and we are ready to @@ -214,7 +214,7 @@ class Log : public LogInterface, virtual public PubsubInterface { /// to subscribe to all modifications, or to subscribe only to keys that it /// requests notifications for. This may only be called once per Log /// instance. This function is different from public version due to - /// an additional parameter notification_mode in NotificationCallback. Therefore this + /// an additional parameter change_mode in NotificationCallback. Therefore this /// function supports notifications of remove operations. /// /// \param driver_id The ID of the job (= driver). @@ -451,6 +451,157 @@ class Set : private Log, using Log::num_lookups_; }; +template +class HashInterface { + public: + using DataT = typename Data::NativeTableType; + using DataMap = std::unordered_map>; + // Reuse Log's SubscriptionCallback when Subscribe is successfully called. + using SubscriptionCallback = typename Log::SubscriptionCallback; + + /// The callback function used by function Update & Lookup. + /// + /// \param client The client on which the RemoveEntries is called. + /// \param id The ID of the Hash Table whose entries are removed. + /// \param data Map data contains the change to the Hash Table. + /// \return Void + using HashCallback = + std::function; + + /// The callback function used by function RemoveEntries. + /// + /// \param client The client on which the RemoveEntries is called. + /// \param id The ID of the Hash Table whose entries are removed. + /// \param keys The keys that are moved from this Hash Table. + /// \return Void + using HashRemoveCallback = std::function &keys)>; + + /// The notification function used by function Subscribe. + /// + /// \param client The client on which the Subscribe is called. + /// \param change_mode The mode to identify the data is removed or updated. 
+ /// \param data Map data contains the change to the Hash Table. + /// \return Void + using HashNotificationCallback = + std::function; + + /// Add entries of a hash table. + /// + /// \param driver_id The ID of the job (= driver). + /// \param id The ID of the data that is added to the GCS. + /// \param pairs Map data to add to the hash table. + /// \param done HashCallback that is called once the request data has been written to + /// the GCS. + /// \return Status + virtual Status Update(const DriverID &driver_id, const ID &id, const DataMap &pairs, + const HashCallback &done) = 0; + + /// Remove entries from the hash table. + /// + /// \param driver_id The ID of the job (= driver). + /// \param id The ID of the data that is removed from the GCS. + /// \param keys The entry keys of the hash table. + /// \param remove_callback HashRemoveCallback that is called once the data has been + /// written to the GCS no matter whether the key exists in the hash table. + /// \return Status + virtual Status RemoveEntries(const DriverID &driver_id, const ID &id, + const std::vector &keys, + const HashRemoveCallback &remove_callback) = 0; + + /// Lookup the map data of a hash table. + /// + /// \param driver_id The ID of the job (= driver). + /// \param id The ID of the data that is looked up in the GCS. + /// \param lookup HashCallback that is called after lookup. If the callback is + /// called with an empty hash table, then there was no data in the callback. + /// \return Status + virtual Status Lookup(const DriverID &driver_id, const ID &id, + const HashCallback &lookup) = 0; + + /// Subscribe to any Update or Remove operations to this hash table. + /// + /// \param driver_id The ID of the driver. + /// \param client_id The type of update to listen to. If this is nil, then a + /// message for each Update to the table will be received. Else, only + /// messages for the given client will be received. 
In the latter + /// case, the client may request notifications on specific keys in the + /// table via `RequestNotifications`. + /// \param subscribe HashNotificationCallback that is called on each received message. + /// \param done SubscriptionCallback that is called when subscription is complete and + /// we are ready to receive messages. + /// \return Status + virtual Status Subscribe(const DriverID &driver_id, const ClientID &client_id, + const HashNotificationCallback &subscribe, + const SubscriptionCallback &done) = 0; + + virtual ~HashInterface(){}; +}; + +template +class Hash : private Log, + public HashInterface, + virtual public PubsubInterface { + public: + using DataT = typename Log::DataT; + using DataMap = std::unordered_map>; + using HashCallback = typename HashInterface::HashCallback; + using HashRemoveCallback = typename HashInterface::HashRemoveCallback; + using HashNotificationCallback = + typename HashInterface::HashNotificationCallback; + using SubscriptionCallback = typename Log::SubscriptionCallback; + + Hash(const std::vector> &contexts, AsyncGcsClient *client) + : Log(contexts, client) {} + + using Log::RequestNotifications; + using Log::CancelNotifications; + + Status Update(const DriverID &driver_id, const ID &id, const DataMap &pairs, + const HashCallback &done) override; + + Status Subscribe(const DriverID &driver_id, const ClientID &client_id, + const HashNotificationCallback &subscribe, + const SubscriptionCallback &done) override; + + Status Lookup(const DriverID &driver_id, const ID &id, + const HashCallback &lookup) override; + + Status RemoveEntries(const DriverID &driver_id, const ID &id, + const std::vector &keys, + const HashRemoveCallback &remove_callback) override; + + /// Returns debug string for class. + /// + /// \return string. 
+ std::string DebugString() const; + + protected: + using Log::shard_contexts_; + using Log::client_; + using Log::pubsub_channel_; + using Log::prefix_; + using Log::subscribe_callback_index_; + using Log::GetRedisContext; + + int64_t num_adds_ = 0; + int64_t num_removes_ = 0; + using Log::num_lookups_; +}; + +class DynamicResourceTable : public Hash { + public: + DynamicResourceTable(const std::vector> &contexts, + AsyncGcsClient *client) + : Hash(contexts, client) { + pubsub_channel_ = TablePubsub::NODE_RESOURCE; + prefix_ = TablePrefix::NODE_RESOURCE; + }; + + virtual ~DynamicResourceTable(){}; +}; + class ObjectTable : public Set { public: ObjectTable(const std::vector> &contexts, diff --git a/src/ray/object_manager/object_directory.cc b/src/ray/object_manager/object_directory.cc index 1f05559f4b87..5b6794a505d3 100644 --- a/src/ray/object_manager/object_directory.cc +++ b/src/ray/object_manager/object_directory.cc @@ -11,16 +11,16 @@ namespace { /// Process a notification of the object table entries and store the result in /// client_ids. This assumes that client_ids already contains the result of the /// object table entries up to but not including this notification. -void UpdateObjectLocations(const GcsTableNotificationMode notification_mode, +void UpdateObjectLocations(const GcsChangeMode change_mode, const std::vector &location_updates, const ray::gcs::ClientTable &client_table, std::unordered_set *client_ids) { // location_updates contains the updates of locations of the object. - // with GcsTableNotificationMode, we can determine whether the update mode is + // with GcsChangeMode, we can determine whether the update mode is // addition or deletion. 
for (const auto &object_table_data : location_updates) { ClientID client_id = ClientID::FromBinary(object_table_data.manager); - if (notification_mode != GcsTableNotificationMode::REMOVE) { + if (change_mode != GcsChangeMode::REMOVE) { client_ids->insert(client_id); } else { client_ids->erase(client_id); @@ -39,37 +39,36 @@ void UpdateObjectLocations(const GcsTableNotificationMode notification_mode, } // namespace void ObjectDirectory::RegisterBackend() { - auto object_notification_callback = [this]( - gcs::AsyncGcsClient *client, const ObjectID &object_id, - const GcsTableNotificationMode notification_mode, - const std::vector &location_updates) { - // Objects are added to this map in SubscribeObjectLocations. - auto it = listeners_.find(object_id); - // Do nothing for objects we are not listening for. - if (it == listeners_.end()) { - return; - } - - // Once this flag is set to true, it should never go back to false. - it->second.subscribed = true; - - // Update entries for this object. - UpdateObjectLocations(notification_mode, location_updates, - gcs_client_->client_table(), - &it->second.current_object_locations); - // Copy the callbacks so that the callbacks can unsubscribe without interrupting - // looping over the callbacks. - auto callbacks = it->second.callbacks; - // Call all callbacks associated with the object id locations we have - // received. This notifies the client even if the list of locations is - // empty, since this may indicate that the objects have been evicted from - // all nodes. - for (const auto &callback_pair : callbacks) { - // It is safe to call the callback directly since this is already running - // in the subscription callback stack. 
- callback_pair.second(object_id, it->second.current_object_locations); - } - }; + auto object_notification_callback = + [this](gcs::AsyncGcsClient *client, const ObjectID &object_id, + const GcsChangeMode change_mode, + const std::vector &location_updates) { + // Objects are added to this map in SubscribeObjectLocations. + auto it = listeners_.find(object_id); + // Do nothing for objects we are not listening for. + if (it == listeners_.end()) { + return; + } + + // Once this flag is set to true, it should never go back to false. + it->second.subscribed = true; + + // Update entries for this object. + UpdateObjectLocations(change_mode, location_updates, gcs_client_->client_table(), + &it->second.current_object_locations); + // Copy the callbacks so that the callbacks can unsubscribe without interrupting + // looping over the callbacks. + auto callbacks = it->second.callbacks; + // Call all callbacks associated with the object id locations we have + // received. This notifies the client even if the list of locations is + // empty, since this may indicate that the objects have been evicted from + // all nodes. + for (const auto &callback_pair : callbacks) { + // It is safe to call the callback directly since this is already running + // in the subscription callback stack. + callback_pair.second(object_id, it->second.current_object_locations); + } + }; RAY_CHECK_OK(gcs_client_->object_table().Subscribe( DriverID::Nil(), gcs_client_->client_table().GetLocalClientId(), object_notification_callback, nullptr)); @@ -135,8 +134,7 @@ void ObjectDirectory::HandleClientRemoved(const ClientID &client_id) { if (listener.second.current_object_locations.count(client_id) > 0) { // If the subscribed object has the removed client as a location, update // its locations with an empty update so that the location will be removed. 
- UpdateObjectLocations(GcsTableNotificationMode::APPEND_OR_ADD, {}, - gcs_client_->client_table(), + UpdateObjectLocations(GcsChangeMode::APPEND_OR_ADD, {}, gcs_client_->client_table(), &listener.second.current_object_locations); // Re-call all the subscribed callbacks for the object, since its // locations have changed. @@ -213,7 +211,7 @@ ray::Status ObjectDirectory::LookupLocations(const ObjectID &object_id, const std::vector &location_updates) { // Build the set of current locations based on the entries in the log. std::unordered_set client_ids; - UpdateObjectLocations(GcsTableNotificationMode::APPEND_OR_ADD, location_updates, + UpdateObjectLocations(GcsChangeMode::APPEND_OR_ADD, location_updates, gcs_client_->client_table(), &client_ids); // It is safe to call the callback directly since this is already running // in the GCS client's lookup callback stack. diff --git a/src/ray/object_manager/test/object_manager_stress_test.cc b/src/ray/object_manager/test/object_manager_stress_test.cc index f1169605134a..55aa59124a99 100644 --- a/src/ray/object_manager/test/object_manager_stress_test.cc +++ b/src/ray/object_manager/test/object_manager_stress_test.cc @@ -70,11 +70,11 @@ class MockServer { void HandleAcceptObjectManager(const boost::system::error_code &error) { ClientHandler client_handler = [this](TcpClientConnection &client) { object_manager_.ProcessNewClient(client); }; - MessageHandler message_handler = [this]( - std::shared_ptr client, int64_t message_type, - const uint8_t *message) { - object_manager_.ProcessClientMessage(client, message_type, message); - }; + MessageHandler message_handler = + [this](std::shared_ptr client, int64_t message_type, + const uint8_t *message) { + object_manager_.ProcessClientMessage(client, message_type, message); + }; // Accept a new local client and dispatch it to the node manager. 
auto new_connection = TcpClientConnection::Create( client_handler, message_handler, std::move(object_manager_socket_), @@ -240,16 +240,17 @@ class StressTestObjectManager : public TestObjectManagerBase { void WaitConnections() { client_id_1 = gcs_client_1->client_table().GetLocalClientId(); client_id_2 = gcs_client_2->client_table().GetLocalClientId(); - gcs_client_1->client_table().RegisterClientAddedCallback([this]( - gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableDataT &data) { - ClientID parsed_id = ClientID::FromBinary(data.client_id); - if (parsed_id == client_id_1 || parsed_id == client_id_2) { - num_connected_clients += 1; - } - if (num_connected_clients == 2) { - StartTests(); - } - }); + gcs_client_1->client_table().RegisterClientAddedCallback( + [this](gcs::AsyncGcsClient *client, const ClientID &id, + const ClientTableDataT &data) { + ClientID parsed_id = ClientID::FromBinary(data.client_id); + if (parsed_id == client_id_1 || parsed_id == client_id_2) { + num_connected_clients += 1; + } + if (num_connected_clients == 2) { + StartTests(); + } + }); } void StartTests() { diff --git a/src/ray/object_manager/test/object_manager_test.cc b/src/ray/object_manager/test/object_manager_test.cc index 012c306938d6..ee6c78d8ed42 100644 --- a/src/ray/object_manager/test/object_manager_test.cc +++ b/src/ray/object_manager/test/object_manager_test.cc @@ -64,11 +64,11 @@ class MockServer { void HandleAcceptObjectManager(const boost::system::error_code &error) { ClientHandler client_handler = [this](TcpClientConnection &client) { object_manager_.ProcessNewClient(client); }; - MessageHandler message_handler = [this]( - std::shared_ptr client, int64_t message_type, - const uint8_t *message) { - object_manager_.ProcessClientMessage(client, message_type, message); - }; + MessageHandler message_handler = + [this](std::shared_ptr client, int64_t message_type, + const uint8_t *message) { + object_manager_.ProcessClientMessage(client, message_type, message); + 
}; // Accept a new local client and dispatch it to the node manager. auto new_connection = TcpClientConnection::Create( client_handler, message_handler, std::move(object_manager_socket_), @@ -219,16 +219,17 @@ class TestObjectManager : public TestObjectManagerBase { void WaitConnections() { client_id_1 = gcs_client_1->client_table().GetLocalClientId(); client_id_2 = gcs_client_2->client_table().GetLocalClientId(); - gcs_client_1->client_table().RegisterClientAddedCallback([this]( - gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableDataT &data) { - ClientID parsed_id = ClientID::FromBinary(data.client_id); - if (parsed_id == client_id_1 || parsed_id == client_id_2) { - num_connected_clients += 1; - } - if (num_connected_clients == 2) { - StartTests(); - } - }); + gcs_client_1->client_table().RegisterClientAddedCallback( + [this](gcs::AsyncGcsClient *client, const ClientID &id, + const ClientTableDataT &data) { + ClientID parsed_id = ClientID::FromBinary(data.client_id); + if (parsed_id == client_id_1 || parsed_id == client_id_2) { + num_connected_clients += 1; + } + if (num_connected_clients == 2) { + StartTests(); + } + }); } void StartTests() { @@ -291,9 +292,10 @@ class TestObjectManager : public TestObjectManagerBase { UniqueID sub_id = ray::UniqueID::FromRandom(); RAY_CHECK_OK(server1->object_manager_.object_directory_->SubscribeObjectLocations( - sub_id, object_1, [this, sub_id, object_1, object_2]( - const ray::ObjectID &object_id, - const std::unordered_set &clients) { + sub_id, object_1, + [this, sub_id, object_1, object_2]( + const ray::ObjectID &object_id, + const std::unordered_set &clients) { if (!clients.empty()) { TestWaitWhileSubscribed(sub_id, object_1, object_2); } diff --git a/src/ray/protobuf/node_manager.proto b/src/ray/protobuf/node_manager.proto new file mode 100644 index 000000000000..8a82da1c77fd --- /dev/null +++ b/src/ray/protobuf/node_manager.proto @@ -0,0 +1,24 @@ +syntax = "proto3"; + +package ray.rpc; + +message 
ForwardTaskRequest { + // The ID of the task to be forwarded. + bytes task_id = 1; + // The tasks in the uncommitted lineage of the forwarded task. This + // should include task_id. + // TODO(hchen): Currently, `uncommitted_tasks` are represented as + // flatbuffers-serialized bytes. This is because the flatbuffers-defined Task data + // structure is being used in many places. We should move Task and all related data + // structures to protobuf. + repeated bytes uncommitted_tasks = 2; +} + +message ForwardTaskReply { +} + +// Service for inter-node-manager communication. +service NodeManagerService { + // Forward a task and its uncommitted lineage to the remote node manager. + rpc ForwardTask(ForwardTaskRequest) returns (ForwardTaskReply); +} diff --git a/src/ray/raylet/client_connection_test.cc b/src/ray/raylet/client_connection_test.cc index 952088bb2f4b..d6359ae281dd 100644 --- a/src/ray/raylet/client_connection_test.cc +++ b/src/ray/raylet/client_connection_test.cc @@ -73,9 +73,9 @@ TEST_F(ClientConnectionTest, SimpleAsyncWrite) { ClientHandler client_handler = [](LocalClientConnection &client) {}; - MessageHandler noop_handler = []( - std::shared_ptr client, int64_t message_type, - const uint8_t *message) {}; + MessageHandler noop_handler = + [](std::shared_ptr client, int64_t message_type, + const uint8_t *message) {}; std::shared_ptr reader = NULL; @@ -120,9 +120,9 @@ TEST_F(ClientConnectionTest, SimpleAsyncError) { ClientHandler client_handler = [](LocalClientConnection &client) {}; - MessageHandler noop_handler = []( - std::shared_ptr client, int64_t message_type, - const uint8_t *message) {}; + MessageHandler noop_handler = + [](std::shared_ptr client, int64_t message_type, + const uint8_t *message) {}; auto writer = LocalClientConnection::Create( client_handler, noop_handler, std::move(in_), "writer", {}, error_message_type_); @@ -142,9 +142,9 @@ TEST_F(ClientConnectionTest, CallbackWithSharedRefDoesNotLeakConnection) { ClientHandler client_handler = 
[](LocalClientConnection &client) {}; - MessageHandler noop_handler = []( - std::shared_ptr client, int64_t message_type, - const uint8_t *message) {}; + MessageHandler noop_handler = + [](std::shared_ptr client, int64_t message_type, + const uint8_t *message) {}; auto writer = LocalClientConnection::Create( client_handler, noop_handler, std::move(in_), "writer", {}, error_message_type_); diff --git a/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc b/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc index 2afcba18c356..319d29d4a93a 100644 --- a/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc +++ b/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc @@ -307,15 +307,15 @@ Java_org_ray_runtime_raylet_RayletClientImpl_nativeNotifyActorResumedFromCheckpo * Method: nativeSetResource * Signature: (JLjava/lang/String;D[B)V */ -JNIEXPORT void JNICALL -Java_org_ray_runtime_raylet_RayletClientImpl_nativeSetResource(JNIEnv *env, jclass, - jlong client, jstring resourceName, jdouble capacity, jbyteArray nodeId) { +JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeSetResource( + JNIEnv *env, jclass, jlong client, jstring resourceName, jdouble capacity, + jbyteArray nodeId) { auto raylet_client = reinterpret_cast(client); UniqueIdFromJByteArray node_id(env, nodeId); const char *native_resource_name = env->GetStringUTFChars(resourceName, JNI_FALSE); - auto status = raylet_client->SetResource(native_resource_name, - static_cast(capacity), node_id.GetId()); + auto status = raylet_client->SetResource( + native_resource_name, static_cast(capacity), node_id.GetId()); env->ReleaseStringUTFChars(resourceName, native_resource_name); ThrowRayExceptionIfNotOK(env, status); } diff --git a/src/ray/raylet/lineage_cache.cc b/src/ray/raylet/lineage_cache.cc index 910c3481bf58..32dddada5244 100644 --- a/src/ray/raylet/lineage_cache.cc +++ b/src/ray/raylet/lineage_cache.cc @@ -68,7 +68,7 @@ 
Lineage::Lineage(const protocol::ForwardTaskRequest &task_request) { auto tasks = task_request.uncommitted_tasks(); for (auto it = tasks->begin(); it != tasks->end(); it++) { const auto &task = **it; - RAY_CHECK(SetEntry(task, GcsStatus::UNCOMMITTED_REMOTE)); + RAY_CHECK(SetEntry(task, GcsStatus::UNCOMMITTED)); } } @@ -108,38 +108,23 @@ bool Lineage::SetEntry(const Task &task, GcsStatus status) { auto task_id = task.GetTaskSpecification().TaskId(); auto it = entries_.find(task_id); bool updated = false; - std::unordered_set old_parents; if (it != entries_.end()) { if (it->second.SetStatus(status)) { - // The task's spec may have changed, so record its old dependencies. - old_parents = it->second.GetParentTaskIds(); - // SetStatus() would check if the new status is greater, - // if it succeeds, go ahead to update the task field. - it->second.UpdateTaskData(task); + // We assume here that the new `task` has the same fields as the task + // already in the lineage cache. If this is not true, then it is + // necessary to update the task data of the existing lineage cache entry + // with LineageEntry::UpdateTaskData. updated = true; } } else { LineageEntry new_entry(task, status); it = entries_.emplace(std::make_pair(task_id, std::move(new_entry))).first; updated = true; - } - // If the task data was updated, then record which tasks it depends on. Add - // all new tasks that it depends on and remove any old tasks that it no - // longer depends on. - // TODO(swang): Updating the task data every time could be inefficient for - // tasks that have lots of dependencies and/or large specs. A flag could be - // passed in for tasks whose data has not changed. - if (updated) { + // New task data was added to the local cache, so record which tasks it + // depends on. Add all new tasks that it depends on. 
for (const auto &parent_id : it->second.GetParentTaskIds()) { - if (old_parents.count(parent_id) == 0) { - AddChild(parent_id, task_id); - } else { - old_parents.erase(parent_id); - } - } - for (const auto &old_parent_id : old_parents) { - RemoveChild(old_parent_id, task_id); + AddChild(parent_id, task_id); } } return updated; @@ -198,15 +183,15 @@ LineageCache::LineageCache(const ClientID &client_id, /// A helper function to add some uncommitted lineage to the local cache. void LineageCache::AddUncommittedLineage(const TaskID &task_id, - const Lineage &uncommitted_lineage, - std::unordered_set &subscribe_tasks) { + const Lineage &uncommitted_lineage) { + RAY_LOG(DEBUG) << "Adding uncommitted task " << task_id << " on " << client_id_; // If the entry is not found in the lineage to merge, then we stop since // there is nothing to copy into the merged lineage. auto entry = uncommitted_lineage.GetEntry(task_id); if (!entry) { return; } - RAY_CHECK(entry->GetStatus() == GcsStatus::UNCOMMITTED_REMOTE); + RAY_CHECK(entry->GetStatus() == GcsStatus::UNCOMMITTED); // Insert a copy of the entry into our cache. const auto &parent_ids = entry->GetParentTaskIds(); @@ -214,90 +199,48 @@ void LineageCache::AddUncommittedLineage(const TaskID &task_id, // if the new entry has an equal or lower GCS status than the current entry // in our cache. This also prevents us from traversing the same node twice. 
if (lineage_.SetEntry(entry->TaskData(), entry->GetStatus())) { - subscribe_tasks.insert(task_id); + RAY_CHECK(SubscribeTask(task_id)); for (const auto &parent_id : parent_ids) { - AddUncommittedLineage(parent_id, uncommitted_lineage, subscribe_tasks); + AddUncommittedLineage(parent_id, uncommitted_lineage); } } } -bool LineageCache::AddWaitingTask(const Task &task, const Lineage &uncommitted_lineage) { - auto task_id = task.GetTaskSpecification().TaskId(); - RAY_LOG(DEBUG) << "Add waiting task " << task_id << " on " << client_id_; - - // Merge the uncommitted lineage into the lineage cache. Collect the IDs of - // tasks that we should subscribe to. These are all of the tasks that were - // included in the uncommitted lineage that we did not already have in our - // stash. - std::unordered_set subscribe_tasks; - AddUncommittedLineage(task_id, uncommitted_lineage, subscribe_tasks); - // Add the submitted task to the lineage cache as UNCOMMITTED_WAITING. It - // should be marked as UNCOMMITTED_READY once the task starts execution. - auto added = lineage_.SetEntry(task, GcsStatus::UNCOMMITTED_WAITING); - - // Do not subscribe to the waiting task itself. We just added it as - // UNCOMMITTED_WAITING, so the task is local. - subscribe_tasks.erase(task_id); - // Unsubscribe to the waiting task since we may have previously been - // subscribed to it. - UnsubscribeTask(task_id); - // Subscribe to all other tasks that were included in the uncommitted lineage - // and that were not already in the local stash. These tasks haven't been - // committed yet and will be committed by a different node, so we will not - // evict them until a notification for their commit is received. 
- for (const auto &task_id : subscribe_tasks) { - RAY_CHECK(SubscribeTask(task_id)); - } - - return added; -} - -bool LineageCache::AddReadyTask(const Task &task) { +bool LineageCache::CommitTask(const Task &task) { const TaskID task_id = task.GetTaskSpecification().TaskId(); - RAY_LOG(DEBUG) << "Add ready task " << task_id << " on " << client_id_; + RAY_LOG(DEBUG) << "Committing task " << task_id << " on " << client_id_; - // Set the task to READY. - if (lineage_.SetEntry(task, GcsStatus::UNCOMMITTED_READY)) { - // Attempt to flush the task. + if (lineage_.SetEntry(task, GcsStatus::UNCOMMITTED) || + lineage_.GetEntry(task_id)->GetStatus() == GcsStatus::UNCOMMITTED) { + // Attempt to flush the task if the task is uncommitted. FlushTask(task_id); return true; } else { - // The task was already ready to be committed (UNCOMMITTED_READY) or - // committing (COMMITTING). + // The task was already committing (COMMITTING). return false; } } -bool LineageCache::RemoveWaitingTask(const TaskID &task_id) { - RAY_LOG(DEBUG) << "Remove waiting task " << task_id << " on " << client_id_; - auto entry = lineage_.GetEntryMutable(task_id); - if (!entry) { - // The task was already evicted. - return false; - } - - // If the task is already not in WAITING status, then exit. This should only - // happen when there are two copies of the task executing at the node, due to - // a spurious reconstruction. Then, either the task is already past WAITING - // status, in which case it will be committed, or it is in - // UNCOMMITTED_REMOTE, in which case it was already removed. - if (entry->GetStatus() != GcsStatus::UNCOMMITTED_WAITING) { - return false; +void LineageCache::FlushAllUncommittedTasks() { + size_t num_flushed = 0; + for (const auto &entry : lineage_.GetEntries()) { + // Flush all tasks that have not yet committed. 
+ if (entry.second.GetStatus() == GcsStatus::UNCOMMITTED) { + RAY_CHECK(UnsubscribeTask(entry.first)); + FlushTask(entry.first); + num_flushed++; + } } - // Reset the status to REMOTE. We keep the task instead of removing it - // completely in case another task is submitted locally that depends on this - // one. - entry->ResetStatus(GcsStatus::UNCOMMITTED_REMOTE); - // The task is now remote, so subscribe to the task to make sure that we'll - // eventually clean it up. - RAY_CHECK(SubscribeTask(task_id)); - return true; + RAY_LOG(DEBUG) << "Flushed " << num_flushed << " uncommitted tasks"; } void LineageCache::MarkTaskAsForwarded(const TaskID &task_id, const ClientID &node_id) { RAY_CHECK(!node_id.IsNil()); - lineage_.GetEntryMutable(task_id)->MarkExplicitlyForwarded(node_id); + auto entry = lineage_.GetEntryMutable(task_id); + if (entry) { + entry->MarkExplicitlyForwarded(node_id); + } } /// A helper function to get the uncommitted lineage of a task. @@ -345,12 +288,11 @@ Lineage LineageCache::GetUncommittedLineageOrDie(const TaskID &task_id, void LineageCache::FlushTask(const TaskID &task_id) { auto entry = lineage_.GetEntryMutable(task_id); RAY_CHECK(entry); - RAY_CHECK(entry->GetStatus() == GcsStatus::UNCOMMITTED_READY); + RAY_CHECK(entry->GetStatus() < GcsStatus::COMMITTING); - gcs::raylet::TaskTable::WriteCallback task_callback = [this]( - ray::gcs::AsyncGcsClient *client, const TaskID &id, const protocol::TaskT &data) { - HandleEntryCommitted(id); - }; + gcs::raylet::TaskTable::WriteCallback task_callback = + [this](ray::gcs::AsyncGcsClient *client, const TaskID &id, + const protocol::TaskT &data) { HandleEntryCommitted(id); }; auto task = lineage_.GetEntry(task_id); // TODO(swang): Make this better... flatbuffers::FlatBufferBuilder fbb; @@ -406,11 +348,6 @@ void LineageCache::EvictTask(const TaskID &task_id) { if (!entry) { return; } - // Only evict tasks that we were subscribed to or that we were committing. 
- if (!(entry->GetStatus() == GcsStatus::UNCOMMITTED_REMOTE || - entry->GetStatus() == GcsStatus::COMMITTING)) { - return; - } // Entries cannot be safely evicted until their parents are all evicted. for (const auto &parent_id : entry->GetParentTaskIds()) { if (ContainsTask(parent_id)) { diff --git a/src/ray/raylet/lineage_cache.h b/src/ray/raylet/lineage_cache.h index 02d98b8cffe6..5436fa372fa4 100644 --- a/src/ray/raylet/lineage_cache.h +++ b/src/ray/raylet/lineage_cache.h @@ -17,19 +17,23 @@ namespace ray { namespace raylet { /// The status of a lineage cache entry according to its status in the GCS. +/// Tasks can only transition to a higher GcsStatus (e.g., an UNCOMMITTED state +/// can become COMMITTING but not vice versa). If a task is evicted from the +/// local cache, it implicitly goes back to state `NONE`, after which it may be +/// added to the local cache again (e.g., if it is forwarded to us again). enum class GcsStatus { /// The task is not in the lineage cache. NONE = 0, - /// The task is being executed or created on a remote node. - UNCOMMITTED_REMOTE, - /// The task is waiting to be executed or created locally. - UNCOMMITTED_WAITING, - /// The task has started execution, but the entry has not been written to the - /// GCS yet. - UNCOMMITTED_READY, - /// The task has been written to the GCS and we are waiting for an - /// acknowledgement of the commit. + /// The task is uncommitted. Unless there is a failure, we will expect a + /// different node to commit this task. + UNCOMMITTED, + /// We flushed this task and are waiting for the commit acknowledgement. COMMITTING, + // TODO(swang): Add a COMMITTED state for tasks for which we received a + // commit acknowledgement, but which we cannot evict yet (due to an ancestor + // that has not been evicted). This is to allow a performance optimization + // that avoids unnecessary subscribes when we receive tasks that were + // already COMMITTED at the sender. 
}; /// \class LineageEntry @@ -220,37 +224,30 @@ class LineageCache { gcs::TableInterface &task_storage, gcs::PubsubInterface &task_pubsub, uint64_t max_lineage_size); - /// Add a task that is waiting for execution and its uncommitted lineage. - /// These entries will not be written to the GCS until set to ready. + /// Asynchronously commit a task to the GCS. /// - /// \param task The waiting task to add. - /// \param uncommitted_lineage The task's uncommitted lineage. These are the - /// tasks that the given task is data-dependent on, but that have not - /// been made durable in the GCS, as far the task's submitter knows. - /// \return Whether the task was successfully marked as waiting to be - /// committed. This will return false if the task is already waiting to be - /// committed (UNCOMMITTED_WAITING), ready to be committed - /// (UNCOMMITTED_READY), or committing (COMMITTING). - bool AddWaitingTask(const Task &task, const Lineage &uncommitted_lineage); - - /// Add a task that is ready for GCS writeback. This overwrites the task’s - /// mutable fields in the execution specification. + /// \param task The task to commit. It will be moved to the COMMITTING state. + /// \return Whether the task was successfully committed. This can fail if the + /// task was already in the COMMITTING state. + bool CommitTask(const Task &task); + + /// Flush all tasks in the local cache that are not already being + /// committed. This is equivalent to all tasks in the UNCOMMITTED + /// state. /// - /// \param task The task to set as ready. - /// \return Whether the task was successfully marked as ready to be - /// committed. This will return false if the task is already ready to be - /// committed (UNCOMMITTED_READY) or committing (COMMITTING). - bool AddReadyTask(const Task &task); - - /// Remove a task that was waiting for execution. Its uncommitted lineage - /// will remain unchanged. + /// \return Void. 
+ void FlushAllUncommittedTasks(); + + /// Add a task and its (estimated) uncommitted lineage to the local cache. We + /// will subscribe to commit notifications for all uncommitted tasks to + /// determine when it is safe to evict the lineage from the local cache. /// - /// \param task_id The ID of the waiting task to remove. - /// \return Whether the task was successfully removed. This will return false - /// if the task is not waiting to be committed. Then, the waiting task has - /// already been removed (UNCOMMITTED_REMOTE), or if it's ready to be - /// committed (UNCOMMITTED_READY) or committing (COMMITTING). - bool RemoveWaitingTask(const TaskID &task_id); + /// \param task_id The ID of the uncommitted task to add. + /// \param uncommitted_lineage The task's uncommitted lineage. These are the + /// tasks that the given task is data-dependent on, but that have not + /// been committed to the GCS. This must contain the given task ID. + /// \return Void. + void AddUncommittedLineage(const TaskID &task_id, const Lineage &uncommitted_lineage); /// Mark a task as having been explicitly forwarded to a node. /// The lineage of the task is implicitly assumed to have also been forwarded. @@ -317,9 +314,6 @@ class LineageCache { /// Unsubscribe from notifications for a task. Returns whether the operation /// was successful (whether we were subscribed). bool UnsubscribeTask(const TaskID &task_id); - /// Add a task and its uncommitted lineage to the local stash. - void AddUncommittedLineage(const TaskID &task_id, const Lineage &uncommitted_lineage, - std::unordered_set &subscribe_tasks); /// The client ID, used to request notifications for specific tasks. /// TODO(swang): Move the ClientID into the generic Table implementation. 
diff --git a/src/ray/raylet/lineage_cache_test.cc b/src/ray/raylet/lineage_cache_test.cc index a61ae846a925..43e64e400292 100644 --- a/src/ray/raylet/lineage_cache_test.cc +++ b/src/ray/raylet/lineage_cache_test.cc @@ -26,8 +26,22 @@ class MockGcs : public gcs::TableInterface, std::shared_ptr &task_data, const gcs::TableInterface::WriteCallback &done) { task_table_[task_id] = task_data; + auto callback = done; + // If we requested notifications for this task ID, send the notification as + // part of the callback. + if (subscribed_tasks_.count(task_id) == 1) { + callback = [this, done](ray::gcs::AsyncGcsClient *client, const TaskID &task_id, + const protocol::TaskT &data) { + done(client, task_id, data); + // If we're subscribed to the task to be added, also send a + // subscription notification. + notification_callback_(client, task_id, data); + }; + } + callbacks_.push_back( - std::pair(done, task_id)); + std::pair(callback, task_id)); + num_task_adds_++; return ray::Status::OK(); } @@ -78,28 +92,34 @@ class MockGcs : public gcs::TableInterface, const int NumRequestedNotifications() const { return num_requested_notifications_; } + const int NumTaskAdds() const { return num_task_adds_; } + private: std::unordered_map> task_table_; std::vector> callbacks_; gcs::raylet::TaskTable::WriteCallback notification_callback_; std::unordered_set subscribed_tasks_; int num_requested_notifications_ = 0; + int num_task_adds_ = 0; }; class LineageCacheTest : public ::testing::Test { public: LineageCacheTest() : max_lineage_size_(10), + num_notifications_(0), mock_gcs_(), lineage_cache_(ClientID::FromRandom(), mock_gcs_, mock_gcs_, max_lineage_size_) { mock_gcs_.Subscribe([this](ray::gcs::AsyncGcsClient *client, const TaskID &task_id, const ray::protocol::TaskT &data) { lineage_cache_.HandleEntryCommitted(task_id); + num_notifications_++; }); } protected: uint64_t max_lineage_size_; + uint64_t num_notifications_; MockGcs mock_gcs_; LineageCache lineage_cache_; }; @@ -122,15 
+142,22 @@ static inline Task ExampleTask(const std::vector &arguments, return task; } +/// Helper method to create a Lineage object with a single task. +Lineage CreateSingletonLineage(const Task &task) { + Lineage singleton_lineage; + singleton_lineage.SetEntry(task, GcsStatus::UNCOMMITTED); + return singleton_lineage; +} + std::vector InsertTaskChain(LineageCache &lineage_cache, std::vector &inserted_tasks, int chain_size, const std::vector &initial_arguments, int64_t num_returns) { - Lineage empty_lineage; std::vector arguments = initial_arguments; for (int i = 0; i < chain_size; i++) { auto task = ExampleTask(arguments, num_returns); - RAY_CHECK(lineage_cache.AddWaitingTask(task, empty_lineage)); + Lineage lineage = CreateSingletonLineage(task); + lineage_cache.AddUncommittedLineage(task.GetTaskSpecification().TaskId(), lineage); inserted_tasks.push_back(task); arguments.clear(); for (int j = 0; j < task.GetTaskSpecification().NumReturns(); j++) { @@ -190,6 +217,34 @@ TEST_F(LineageCacheTest, TestGetUncommittedLineageOrDie) { } } +TEST_F(LineageCacheTest, TestDuplicateUncommittedLineage) { + // Insert a chain of tasks. + std::vector tasks; + auto return_values = + InsertTaskChain(lineage_cache_, tasks, 3, std::vector(), 1); + std::vector task_ids; + for (const auto &task : tasks) { + task_ids.push_back(task.GetTaskSpecification().TaskId()); + } + // Check that we subscribed to each of the uncommitted tasks. + ASSERT_EQ(mock_gcs_.NumRequestedNotifications(), task_ids.size()); + + // Check that if we add the same tasks as UNCOMMITTED again, we do not issue + // duplicate subscribe requests. 
+ Lineage duplicate_lineage; + for (const auto &task : tasks) { + duplicate_lineage.SetEntry(task, GcsStatus::UNCOMMITTED); + } + lineage_cache_.AddUncommittedLineage(task_ids.back(), duplicate_lineage); + ASSERT_EQ(mock_gcs_.NumRequestedNotifications(), task_ids.size()); + + // Check that if we commit one of the tasks, we still do not issue any + // duplicate subscribe requests. + lineage_cache_.CommitTask(tasks.front()); + lineage_cache_.AddUncommittedLineage(task_ids.back(), duplicate_lineage); + ASSERT_EQ(mock_gcs_.NumRequestedNotifications(), task_ids.size()); +} + TEST_F(LineageCacheTest, TestMarkTaskAsForwarded) { // Insert chain of tasks. std::vector tasks; @@ -222,7 +277,7 @@ TEST_F(LineageCacheTest, TestMarkTaskAsForwarded) { ASSERT_EQ(1, uncommitted_lineage_forwarded.GetEntries().size()); } -TEST_F(LineageCacheTest, TestWritebackNoneReady) { +TEST_F(LineageCacheTest, TestWritebackReady) { // Insert a chain of dependent tasks. size_t num_tasks_flushed = 0; std::vector tasks; @@ -231,16 +286,9 @@ TEST_F(LineageCacheTest, TestWritebackNoneReady) { // Check that when no tasks have been marked as ready, we do not flush any // entries. ASSERT_EQ(mock_gcs_.TaskTable().size(), num_tasks_flushed); -} - -TEST_F(LineageCacheTest, TestWritebackReady) { - // Insert a chain of dependent tasks. - size_t num_tasks_flushed = 0; - std::vector tasks; - InsertTaskChain(lineage_cache_, tasks, 3, std::vector(), 1); // Check that after marking the first task as ready, we flush only that task. - ASSERT_TRUE(lineage_cache_.AddReadyTask(tasks.front())); + ASSERT_TRUE(lineage_cache_.CommitTask(tasks.front())); num_tasks_flushed++; ASSERT_EQ(mock_gcs_.TaskTable().size(), num_tasks_flushed); } @@ -253,7 +301,7 @@ TEST_F(LineageCacheTest, TestWritebackOrder) { // Mark all tasks as ready. All tasks should be flushed. 
for (const auto &task : tasks) { - ASSERT_TRUE(lineage_cache_.AddReadyTask(task)); + ASSERT_TRUE(lineage_cache_.CommitTask(task)); } ASSERT_EQ(mock_gcs_.TaskTable().size(), num_tasks_flushed); @@ -272,12 +320,13 @@ TEST_F(LineageCacheTest, TestEvictChain) { Lineage uncommitted_lineage; for (const auto &task : tasks) { - uncommitted_lineage.SetEntry(task, GcsStatus::UNCOMMITTED_REMOTE); + uncommitted_lineage.SetEntry(task, GcsStatus::UNCOMMITTED); } // Mark the last task as ready to flush. - ASSERT_TRUE(lineage_cache_.AddWaitingTask(tasks.back(), uncommitted_lineage)); + lineage_cache_.AddUncommittedLineage(tasks.back().GetTaskSpecification().TaskId(), + uncommitted_lineage); ASSERT_EQ(lineage_cache_.GetLineage().GetEntries().size(), tasks.size()); - ASSERT_TRUE(lineage_cache_.AddReadyTask(tasks.back())); + ASSERT_TRUE(lineage_cache_.CommitTask(tasks.back())); num_tasks_flushed++; ASSERT_EQ(mock_gcs_.TaskTable().size(), num_tasks_flushed); // Flush acknowledgements. The lineage cache should receive the commit for @@ -320,17 +369,20 @@ TEST_F(LineageCacheTest, TestEvictManyParents) { auto task = ExampleTask({}, 1); parent_tasks.push_back(task); arguments.push_back(task.GetTaskSpecification().ReturnId(0)); - ASSERT_TRUE(lineage_cache_.AddWaitingTask(task, Lineage())); + auto lineage = CreateSingletonLineage(task); + lineage_cache_.AddUncommittedLineage(task.GetTaskSpecification().TaskId(), lineage); } // Create a child task that is dependent on all of the previous tasks. auto child_task = ExampleTask(arguments, 1); - ASSERT_TRUE(lineage_cache_.AddWaitingTask(child_task, Lineage())); + auto lineage = CreateSingletonLineage(child_task); + lineage_cache_.AddUncommittedLineage(child_task.GetTaskSpecification().TaskId(), + lineage); // Flush the child task. Make sure that it remains in the cache, since none // of its parents have been committed yet, and that the uncommitted lineage // still includes all of the parent tasks. 
size_t total_tasks = parent_tasks.size() + 1; - lineage_cache_.AddReadyTask(child_task); + lineage_cache_.CommitTask(child_task); mock_gcs_.Flush(); ASSERT_EQ(lineage_cache_.GetLineage().GetEntries().size(), total_tasks); ASSERT_EQ(lineage_cache_ @@ -342,7 +394,7 @@ TEST_F(LineageCacheTest, TestEvictManyParents) { // Flush each parent task and check for eviction safety. for (const auto &parent_task : parent_tasks) { - lineage_cache_.AddReadyTask(parent_task); + lineage_cache_.CommitTask(parent_task); mock_gcs_.Flush(); total_tasks--; if (total_tasks > 1) { @@ -364,75 +416,6 @@ TEST_F(LineageCacheTest, TestEvictManyParents) { ASSERT_EQ(lineage_cache_.GetLineage().GetChildrenSize(), 0); } -TEST_F(LineageCacheTest, TestForwardTasksRoundTrip) { - // Insert a chain of dependent tasks. - uint64_t lineage_size = max_lineage_size_ + 1; - std::vector tasks; - InsertTaskChain(lineage_cache_, tasks, lineage_size, std::vector(), 1); - - // Simulate removing each task, forwarding it to another node, then - // receiving the task back again. - for (auto it = tasks.begin(); it != tasks.end(); it++) { - const auto task_id = it->GetTaskSpecification().TaskId(); - // Simulate removing the task and forwarding it to another node. - auto uncommitted_lineage = - lineage_cache_.GetUncommittedLineageOrDie(task_id, ClientID::Nil()); - ASSERT_TRUE(lineage_cache_.RemoveWaitingTask(task_id)); - // Simulate receiving the task again. Make sure we can add the task back. - flatbuffers::FlatBufferBuilder fbb; - auto uncommitted_lineage_message = uncommitted_lineage.ToFlatbuffer(fbb, task_id); - fbb.Finish(uncommitted_lineage_message); - uncommitted_lineage = Lineage( - *flatbuffers::GetRoot(fbb.GetBufferPointer())); - ASSERT_TRUE(lineage_cache_.AddWaitingTask(*it, uncommitted_lineage)); - } -} - -TEST_F(LineageCacheTest, TestForwardTask) { - // Insert a chain of dependent tasks. 
- size_t num_tasks_flushed = 0; - std::vector tasks; - InsertTaskChain(lineage_cache_, tasks, 3, std::vector(), 1); - - // Simulate removing the task and forwarding it to another node. - auto it = tasks.begin() + 1; - auto forwarded_task = *it; - tasks.erase(it); - auto task_id_to_remove = forwarded_task.GetTaskSpecification().TaskId(); - auto uncommitted_lineage = - lineage_cache_.GetUncommittedLineageOrDie(task_id_to_remove, ClientID::Nil()); - ASSERT_TRUE(lineage_cache_.RemoveWaitingTask(task_id_to_remove)); - ASSERT_EQ(lineage_cache_.GetLineage().GetEntries().size(), 3); - - // Simulate executing the remaining tasks. - for (const auto &task : tasks) { - ASSERT_TRUE(lineage_cache_.AddReadyTask(task)); - num_tasks_flushed++; - } - // Check that the first task, which has no dependencies can be flushed. The - // last task cannot be flushed since one of its dependencies has not been - // added by the remote node yet. - ASSERT_EQ(mock_gcs_.TaskTable().size(), num_tasks_flushed); - mock_gcs_.Flush(); - ASSERT_EQ(lineage_cache_.GetLineage().GetEntries().size(), 2); - - // Simulate executing the task on a remote node and adding it to the GCS. - auto task_data = std::make_shared(); - RAY_CHECK_OK( - mock_gcs_.RemoteAdd(forwarded_task.GetTaskSpecification().TaskId(), task_data)); - // Check that the remote task is flushed. - num_tasks_flushed++; - ASSERT_EQ(mock_gcs_.TaskTable().size(), num_tasks_flushed); - ASSERT_EQ(mock_gcs_.SubscribedTasks().size(), 1); - - // Check that once we receive the callback for the remote task, we can now - // flush the last task. - mock_gcs_.Flush(); - ASSERT_EQ(mock_gcs_.SubscribedTasks().size(), 0); - ASSERT_EQ(lineage_cache_.GetLineage().GetEntries().size(), 0); - ASSERT_EQ(lineage_cache_.GetLineage().GetChildrenSize(), 0); -} - TEST_F(LineageCacheTest, TestEviction) { // Insert a chain of dependent tasks. 
uint64_t lineage_size = max_lineage_size_ + 1; @@ -440,12 +423,6 @@ TEST_F(LineageCacheTest, TestEviction) { std::vector tasks; InsertTaskChain(lineage_cache_, tasks, lineage_size, std::vector(), 1); - // Simulate forwarding the chain of tasks to a remote node. - for (const auto &task : tasks) { - auto task_id = task.GetTaskSpecification().TaskId(); - ASSERT_TRUE(lineage_cache_.RemoveWaitingTask(task_id)); - } - // Check that the last task in the chain still has all tasks in its // uncommitted lineage. const auto last_task_id = tasks.back().GetTaskSpecification().TaskId(); @@ -500,12 +477,6 @@ TEST_F(LineageCacheTest, TestOutOfOrderEviction) { std::vector tasks; InsertTaskChain(lineage_cache_, tasks, lineage_size, std::vector(), 1); - // Simulate forwarding the chain of tasks to a remote node. - for (const auto &task : tasks) { - auto task_id = task.GetTaskSpecification().TaskId(); - ASSERT_TRUE(lineage_cache_.RemoveWaitingTask(task_id)); - } - // Check that the last task in the chain still has all tasks in its // uncommitted lineage. const auto last_task_id = tasks.back().GetTaskSpecification().TaskId(); @@ -545,19 +516,15 @@ TEST_F(LineageCacheTest, TestEvictionUncommittedChildren) { std::vector tasks; InsertTaskChain(lineage_cache_, tasks, lineage_size, std::vector(), 1); - // Simulate forwarding the chain of tasks to a remote node. - for (const auto &task : tasks) { - auto task_id = task.GetTaskSpecification().TaskId(); - ASSERT_TRUE(lineage_cache_.RemoveWaitingTask(task_id)); - } - // Add more tasks to the lineage cache that will remain local. Each of these // tasks is dependent one of the tasks that was forwarded above. 
for (const auto &task : tasks) { auto return_id = task.GetTaskSpecification().ReturnId(0); auto dependent_task = ExampleTask({return_id}, 1); - ASSERT_TRUE(lineage_cache_.AddWaitingTask(dependent_task, Lineage())); - ASSERT_TRUE(lineage_cache_.AddReadyTask(dependent_task)); + auto lineage = CreateSingletonLineage(dependent_task); + lineage_cache_.AddUncommittedLineage(dependent_task.GetTaskSpecification().TaskId(), + lineage); + ASSERT_TRUE(lineage_cache_.CommitTask(dependent_task)); // Once the forwarded tasks are evicted from the lineage cache, we expect // each of these dependent tasks to be flushed, since all of their // dependencies have been committed. @@ -582,6 +549,39 @@ TEST_F(LineageCacheTest, TestEvictionUncommittedChildren) { ASSERT_EQ(lineage_cache_.GetLineage().GetChildrenSize(), 0); } +TEST_F(LineageCacheTest, TestFlushAllUncommittedTasks) { + // Insert a chain of tasks. + std::vector tasks; + auto return_values = + InsertTaskChain(lineage_cache_, tasks, 3, std::vector(), 1); + std::vector task_ids; + for (const auto &task : tasks) { + task_ids.push_back(task.GetTaskSpecification().TaskId()); + } + // Check that we subscribed to each of the uncommitted tasks. + ASSERT_EQ(mock_gcs_.NumRequestedNotifications(), task_ids.size()); + + // Flush all uncommitted tasks and make sure we add all tasks to + // the task table. + lineage_cache_.FlushAllUncommittedTasks(); + ASSERT_EQ(mock_gcs_.NumTaskAdds(), tasks.size()); + // Flush again and make sure there are no new tasks added to the + // task table. + lineage_cache_.FlushAllUncommittedTasks(); + ASSERT_EQ(mock_gcs_.NumTaskAdds(), tasks.size()); + + // Flush all GCS notifications. + mock_gcs_.Flush(); + // Make sure that we unsubscribed to the uncommitted tasks before + // we flushed them. + ASSERT_EQ(num_notifications_, 0); + + // Flush again and make sure there are no new tasks added to the + // task table. 
+ lineage_cache_.FlushAllUncommittedTasks(); + ASSERT_EQ(mock_gcs_.NumTaskAdds(), tasks.size()); +} + } // namespace raylet } // namespace ray diff --git a/src/ray/raylet/main.cc b/src/ray/raylet/main.cc index 003e48370dbf..eca282a53309 100644 --- a/src/ray/raylet/main.cc +++ b/src/ray/raylet/main.cc @@ -22,6 +22,7 @@ DEFINE_string(python_worker_command, "", "Python worker command."); DEFINE_string(java_worker_command, "", "Java worker command."); DEFINE_string(redis_password, "", "The password of redis."); DEFINE_string(temp_dir, "", "Temporary directory."); +DEFINE_string(session_dir, "", "The path of this ray session directory."); DEFINE_bool(disable_stats, false, "Whether disable the stats."); DEFINE_string(stat_address, "127.0.0.1:8888", "The address that we report metrics to."); DEFINE_bool(enable_stdout_exporter, false, @@ -61,6 +62,7 @@ int main(int argc, char *argv[]) { const std::string java_worker_command = FLAGS_java_worker_command; const std::string redis_password = FLAGS_redis_password; const std::string temp_dir = FLAGS_temp_dir; + const std::string session_dir = FLAGS_session_dir; const bool disable_stats = FLAGS_disable_stats; const std::string stat_address = FLAGS_stat_address; const bool enable_stdout_exporter = FLAGS_enable_stdout_exporter; @@ -69,7 +71,7 @@ int main(int argc, char *argv[]) { // Initialize stats. const ray::stats::TagsType global_tags = { {ray::stats::JobNameKey, "raylet"}, - {ray::stats::VersionKey, "0.7.0"}, + {ray::stats::VersionKey, "0.7.1"}, {ray::stats::NodeAddressKey, node_ip_address}}; ray::stats::Init(stat_address, global_tags, disable_stats, enable_stdout_exporter); @@ -132,6 +134,7 @@ int main(int argc, char *argv[]) { node_manager_config.max_lineage_size = RayConfig::instance().max_lineage_size(); node_manager_config.store_socket_name = store_socket_name; node_manager_config.temp_dir = temp_dir; + node_manager_config.session_dir = session_dir; // Configuration for the object manager. 
ray::ObjectManagerConfig object_manager_config; @@ -171,7 +174,7 @@ int main(int argc, char *argv[]) { // instead of returning immediately. // We should stop the service and remove the local socket file. auto handler = [&main_service, &raylet_socket_name, &server, &gcs_client]( - const boost::system::error_code &error, int signal_number) { + const boost::system::error_code &error, int signal_number) { auto shutdown_callback = [&server, &main_service]() { server.reset(); main_service.stop(); diff --git a/src/ray/raylet/monitor.cc b/src/ray/raylet/monitor.cc index a87257cadda4..62ecb00b819f 100644 --- a/src/ray/raylet/monitor.cc +++ b/src/ray/raylet/monitor.cc @@ -48,8 +48,8 @@ void Monitor::Tick() { auto client_id = it->first; RAY_LOG(WARNING) << "Client timed out: " << client_id; auto lookup_callback = [this, client_id]( - gcs::AsyncGcsClient *client, const ClientID &id, - const std::vector &all_data) { + gcs::AsyncGcsClient *client, const ClientID &id, + const std::vector &all_data) { bool marked = false; for (const auto &data : all_data) { if (client_id.Binary() == data.client_id && diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index e3fd9a0df09f..a0bde1ff0655 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -99,9 +99,9 @@ NodeManager::NodeManager(boost::asio::io_service &io_service, lineage_cache_(gcs_client_->client_table().GetLocalClientId(), gcs_client_->raylet_task_table(), gcs_client_->raylet_task_table(), config.max_lineage_size), - remote_clients_(), - remote_server_connections_(), - actor_registry_() { + actor_registry_(), + node_manager_server_(config.node_manager_port, io_service, *this), + client_call_manager_(io_service) { RAY_CHECK(heartbeat_period_.count() > 0); // Initialize the resource map with own cluster resource configuration. 
 ClientID local_client_id = gcs_client_->client_table().GetLocalClientId(); @@ -117,6 +117,8 @@ NodeManager::NodeManager(boost::asio::io_service &io_service, [this](const ObjectID &object_id) { HandleObjectMissing(object_id); })); RAY_ARROW_CHECK_OK(store_client_.Connect(config.store_socket_name.c_str())); + // Run the node manager rpc server. + node_manager_server_.Run(); } ray::Status NodeManager::RegisterGcs() { @@ -179,45 +181,42 @@ ray::Status NodeManager::RegisterGcs() { }; gcs_client_->client_table().RegisterClientAddedCallback(node_manager_client_added); // Register a callback on the client table for removed clients. - auto node_manager_client_removed = [this]( - gcs::AsyncGcsClient *client, const UniqueID &id, const ClientTableDataT &data) { - ClientRemoved(data); - }; + auto node_manager_client_removed = + [this](gcs::AsyncGcsClient *client, const UniqueID &id, + const ClientTableDataT &data) { ClientRemoved(data); }; gcs_client_->client_table().RegisterClientRemovedCallback(node_manager_client_removed); // Register a callback on the client table for resource create/update requests - auto node_manager_resource_createupdated = [this]( - gcs::AsyncGcsClient *client, const UniqueID &id, const ClientTableDataT &data) { - ResourceCreateUpdated(data); - }; + auto node_manager_resource_createupdated = + [this](gcs::AsyncGcsClient *client, const UniqueID &id, + const ClientTableDataT &data) { ResourceCreateUpdated(data); }; gcs_client_->client_table().RegisterResourceCreateUpdatedCallback( node_manager_resource_createupdated); // Register a callback on the client table for resource delete requests - auto node_manager_resource_deleted = [this]( - gcs::AsyncGcsClient *client, const UniqueID &id, const ClientTableDataT &data) { - ResourceDeleted(data); - }; + auto node_manager_resource_deleted = + [this](gcs::AsyncGcsClient *client, const UniqueID &id, + const ClientTableDataT &data) { ResourceDeleted(data); }; 
gcs_client_->client_table().RegisterResourceDeletedCallback( node_manager_resource_deleted); // Subscribe to heartbeat batches from the monitor. - const auto &heartbeat_batch_added = [this]( - gcs::AsyncGcsClient *client, const ClientID &id, - const HeartbeatBatchTableDataT &heartbeat_batch) { - HeartbeatBatchAdded(heartbeat_batch); - }; + const auto &heartbeat_batch_added = + [this](gcs::AsyncGcsClient *client, const ClientID &id, + const HeartbeatBatchTableDataT &heartbeat_batch) { + HeartbeatBatchAdded(heartbeat_batch); + }; RAY_RETURN_NOT_OK(gcs_client_->heartbeat_batch_table().Subscribe( DriverID::Nil(), ClientID::Nil(), heartbeat_batch_added, /*subscribe_callback=*/nullptr, /*done_callback=*/nullptr)); // Subscribe to driver table updates. - const auto driver_table_handler = [this]( - gcs::AsyncGcsClient *client, const DriverID &client_id, - const std::vector &driver_data) { - HandleDriverTableUpdate(client_id, driver_data); - }; + const auto driver_table_handler = + [this](gcs::AsyncGcsClient *client, const DriverID &client_id, + const std::vector &driver_data) { + HandleDriverTableUpdate(client_id, driver_data); + }; RAY_RETURN_NOT_OK(gcs_client_->driver_table().Subscribe( DriverID::Nil(), ClientID::Nil(), driver_table_handler, nullptr)); @@ -369,66 +368,24 @@ void NodeManager::ClientAdded(const ClientTableDataT &client_data) { return; } - // TODO(atumanov): make remote client lookup O(1) - if (std::find(remote_clients_.begin(), remote_clients_.end(), client_id) == - remote_clients_.end()) { - remote_clients_.push_back(client_id); - } else { - // NodeManager connection to this client was already established. 
- RAY_LOG(DEBUG) << "Received a new client connection that already exists: " + auto entry = remote_node_manager_clients_.find(client_id); + if (entry != remote_node_manager_clients_.end()) { + RAY_LOG(DEBUG) << "Received notification of a new client that already exists: " << client_id; return; } - // Establish a new NodeManager connection to this GCS client. - auto status = ConnectRemoteNodeManager(client_id, client_data.node_manager_address, - client_data.node_manager_port); - if (!status.ok()) { - // This is not a fatal error for raylet, but it should not happen. - // We need to broadcase this message. - std::string type = "raylet_connection_error"; - std::ostringstream error_message; - error_message << "Failed to connect to ray node " << client_id - << " with status: " << status.ToString() - << ". This may be since the node was recently removed."; - // We use the nil DriverID to broadcast the message to all drivers. - RAY_CHECK_OK(gcs_client_->error_table().PushErrorToDriver( - DriverID::Nil(), type, error_message.str(), current_time_ms())); - return; - } + // Initialize an RPC client to the new node manager. + std::unique_ptr client( + new rpc::NodeManagerClient(client_data.node_manager_address, + client_data.node_manager_port, client_call_manager_)); + remote_node_manager_clients_.emplace(client_id, std::move(client)); ResourceSet resources_total(client_data.resources_total_label, client_data.resources_total_capacity); cluster_resource_map_.emplace(client_id, SchedulingResources(resources_total)); } -ray::Status NodeManager::ConnectRemoteNodeManager(const ClientID &client_id, - const std::string &client_address, - int32_t client_port) { - // Establish a new NodeManager connection to this GCS client. 
- RAY_LOG(INFO) << "[ConnectClient] Trying to connect to client " << client_id << " at " - << client_address << ":" << client_port; - - boost::asio::ip::tcp::socket socket(io_service_); - RAY_RETURN_NOT_OK(TcpConnect(socket, client_address, client_port)); - - // The client is connected, now send a connect message to remote node manager. - auto server_conn = TcpServerConnection::Create(std::move(socket)); - - // Prepare client connection info buffer - flatbuffers::FlatBufferBuilder fbb; - auto message = protocol::CreateConnectClient(fbb, to_flatbuf(fbb, client_id_)); - fbb.Finish(message); - // Send synchronously. - // TODO(swang): Make this a WriteMessageAsync. - RAY_RETURN_NOT_OK(server_conn->WriteMessage( - static_cast(protocol::MessageType::ConnectClient), fbb.GetSize(), - fbb.GetBufferPointer())); - - remote_server_connections_.emplace(client_id, std::move(server_conn)); - return ray::Status::OK(); -} - void NodeManager::ClientRemoved(const ClientTableDataT &client_data) { // TODO(swang): If we receive a notification for our own death, clean up and // exit immediately. @@ -443,17 +400,13 @@ void NodeManager::ClientRemoved(const ClientTableDataT &client_data) { // check that it is actually removed, or log a warning otherwise, but that may // not be necessary. - // Remove the client from the list of remote clients. - std::remove(remote_clients_.begin(), remote_clients_.end(), client_id); - // Remove the client from the resource map. cluster_resource_map_.erase(client_id); - // Remove the remote server connection. - const auto connection_entry = remote_server_connections_.find(client_id); - if (connection_entry != remote_server_connections_.end()) { - connection_entry->second->Close(); - remote_server_connections_.erase(connection_entry); + // Remove the node manager client. 
+ const auto client_entry = remote_node_manager_clients_.find(client_id); + if (client_entry != remote_node_manager_clients_.end()) { + remote_node_manager_clients_.erase(client_entry); } else { RAY_LOG(WARNING) << "Received ClientRemoved callback for an unknown client " << client_id << "."; @@ -475,6 +428,11 @@ void NodeManager::ClientRemoved(const ClientTableDataT &client_data) { // Notify the object directory that the client has been removed so that it // can remove it from any cached locations. object_directory_->HandleClientRemoved(client_id); + + // Flush all uncommitted tasks from the local lineage cache. This is to + // guarantee that all tasks get flushed eventually, in case one of the tasks + // in our local cache was supposed to be flushed by the node that died. + lineage_cache_.FlushAllUncommittedTasks(); } void NodeManager::ResourceCreateUpdated(const ClientTableDataT &client_data) { @@ -693,11 +651,6 @@ void NodeManager::HandleActorStateTransition(const ActorID &actor_id, // known. auto created_actor_methods = local_queues_.RemoveTasks(created_actor_method_ids); for (const auto &method : created_actor_methods) { - if (!lineage_cache_.RemoveWaitingTask(method.GetTaskSpecification().TaskId())) { - RAY_LOG(WARNING) << "Task " << method.GetTaskSpecification().TaskId() - << " already removed from the lineage cache. This is most " - "likely due to reconstruction."; - } // Maintain the invariant that if a task is in the // MethodsWaitingForActorCreation queue, then it is subscribed to its // respective actor creation task. 
Since the actor location is now known, @@ -1244,41 +1197,24 @@ void NodeManager::ProcessNewNodeManager(TcpClientConnection &node_manager_client node_manager_client.ProcessMessages(); } -void NodeManager::ProcessNodeManagerMessage(TcpClientConnection &node_manager_client, - int64_t message_type, - const uint8_t *message_data) { - const auto message_type_value = static_cast(message_type); - RAY_LOG(DEBUG) << "[NodeManager] Message " - << protocol::EnumNameMessageType(message_type_value) << "(" - << message_type << ") from node manager"; - switch (message_type_value) { - case protocol::MessageType::ConnectClient: { - auto message = flatbuffers::GetRoot(message_data); - auto client_id = from_flatbuf(*message->client_id()); - node_manager_client.SetClientID(client_id); - } break; - case protocol::MessageType::ForwardTaskRequest: { - auto message = flatbuffers::GetRoot(message_data); - TaskID task_id = from_flatbuf(*message->task_id()); - - Lineage uncommitted_lineage(*message); - const Task &task = uncommitted_lineage.GetEntry(task_id)->TaskData(); - RAY_LOG(DEBUG) << "Received forwarded task " << task.GetTaskSpecification().TaskId() - << " on node " << gcs_client_->client_table().GetLocalClientId() - << " spillback=" << task.GetTaskExecutionSpec().NumForwards(); - SubmitTask(task, uncommitted_lineage, /* forwarded = */ true); - } break; - case protocol::MessageType::DisconnectClient: { - // TODO(rkn): We need to do some cleanup here. - RAY_LOG(DEBUG) << "Received disconnect message from remote node manager. " - << "We need to do some cleanup here."; - // Do not process any more messages from this node manager. 
- return; - } break; - default: - RAY_LOG(FATAL) << "Received unexpected message type " << message_type; - } - node_manager_client.ProcessMessages(); +void NodeManager::HandleForwardTask(const rpc::ForwardTaskRequest &request, + rpc::ForwardTaskReply *reply, + rpc::RequestDoneCallback done_callback) { + // Get the forwarded task and its uncommitted lineage from the request. + TaskID task_id = TaskID::FromBinary(request.task_id()); + Lineage uncommitted_lineage; + for (int i = 0; i < request.uncommitted_tasks_size(); i++) { + const std::string &task_message = request.uncommitted_tasks(i); + const Task task(*flatbuffers::GetRoot( + reinterpret_cast(task_message.data()))); + RAY_CHECK(uncommitted_lineage.SetEntry(std::move(task), GcsStatus::UNCOMMITTED)); + } + const Task &task = uncommitted_lineage.GetEntry(task_id)->TaskData(); + RAY_LOG(DEBUG) << "Received forwarded task " << task.GetTaskSpecification().TaskId() + << " on node " << gcs_client_->client_table().GetLocalClientId() + << " spillback=" << task.GetTaskExecutionSpec().NumForwards(); + SubmitTask(task, uncommitted_lineage, /* forwarded = */ true); + done_callback(Status::OK()); } void NodeManager::ProcessSetResourceRequest( @@ -1466,10 +1402,6 @@ void NodeManager::TreatTaskAsFailed(const Task &task, const ErrorType &error_typ current_time_ms())); } } - // A task failing is equivalent to assigning and finishing the task, so clean - // up any leftover state as for any task dispatched and removed from the - // local queue. - lineage_cache_.AddReadyTask(task); task_dependency_manager_.TaskCanceled(spec.TaskId()); // Notify the task dependency manager that we no longer need this task's // object dependencies. TODO(swang): Ideally, we would check the return value @@ -1538,10 +1470,14 @@ void NodeManager::SubmitTask(const Task &task, const Lineage &uncommitted_lineag } // Add the task and its uncommitted lineage to the lineage cache. 
- if (!lineage_cache_.AddWaitingTask(task, uncommitted_lineage)) { - RAY_LOG(WARNING) - << "Task " << task_id - << " already in lineage cache. This is most likely due to reconstruction."; + if (forwarded) { + lineage_cache_.AddUncommittedLineage(task_id, uncommitted_lineage); + } else { + if (!lineage_cache_.CommitTask(task)) { + RAY_LOG(WARNING) + << "Task " << task_id + << " already committed to the GCS. This is most likely due to reconstruction."; + } } if (spec.IsActorTask()) { @@ -1814,9 +1750,9 @@ bool NodeManager::AssignTask(const Task &task) { cluster_resource_map_[my_client_id].Acquire(spec.GetRequiredResources()); if (spec.IsActorCreationTask()) { - // Check that we are not placing an actor creation task on a node with 0 CPUs. - RAY_CHECK(cluster_resource_map_[my_client_id].GetTotalResources().GetResourceMap().at( - kCPU_ResourceLabel) != 0); + // Check that the actor's placement resource requirements are satisfied. + RAY_CHECK(spec.GetRequiredPlacementResources().IsSubset( + cluster_resource_map_[my_client_id].GetTotalResources())); worker->SetLifetimeResourceIds(acquired_resources); } else { worker->SetTaskResourceIds(acquired_resources); @@ -1869,32 +1805,14 @@ bool NodeManager::AssignTask(const Task &task) { actor_entry->second.AddHandle(new_handle_id, execution_dependency); } - // If the task was an actor task, then record this execution to - // guarantee consistency in the case of reconstruction. - auto execution_dependency = actor_entry->second.GetExecutionDependency(); - // The execution dependency is initialized to the actor creation task's - // return value, and is subsequently updated to the assigned tasks' - // return values, so it should never be nil. - RAY_CHECK(!execution_dependency.IsNil()); - // Update the task's execution dependencies to reflect the actual - // execution order, to support deterministic reconstruction. - // NOTE(swang): The update of an actor task's execution dependencies is - // performed asynchronously. 
This means that if this node manager dies, - // we may lose updates that are in flight to the task table. We only - // guarantee deterministic reconstruction ordering for tasks whose - // updates are reflected in the task table. - // (SetExecutionDependencies takes a non-const so copy task in a - // on-const variable.) - assigned_task.SetExecutionDependencies({execution_dependency}); + // TODO(swang): For actors with multiple actor handles, to + // guarantee that tasks are replayed in the same order after a + // failure, we must update the task's execution dependency to be + // the actor's current execution dependency. } else { RAY_CHECK(spec.NewActorHandles().empty()); } - // We started running the task, so the task is ready to write to GCS. - if (!lineage_cache_.AddReadyTask(assigned_task)) { - RAY_LOG(WARNING) << "Task " << spec.TaskId() << " already in lineage cache." - << " This is most likely due to reconstruction."; - } // Mark the task as running. // (See design_docs/task_states.rst for the state transition diagram.) local_queues_.QueueTasks({assigned_task}, TaskState::RUNNING); @@ -2166,7 +2084,7 @@ void NodeManager::HandleObjectLocal(const ObjectID &object_id) { // Notify the task dependency manager that this object is local. const auto ready_task_ids = task_dependency_manager_.HandleObjectLocal(object_id); RAY_LOG(DEBUG) << "Object local " << object_id << ", " - << " on " << gcs_client_->client_table().GetLocalClientId() + << " on " << gcs_client_->client_table().GetLocalClientId() << ", " << ready_task_ids.size() << " tasks ready"; // Transition the tasks whose dependencies are now fulfilled to the ready state. if (ready_task_ids.size() > 0) { @@ -2220,61 +2138,70 @@ void NodeManager::ForwardTaskOrResubmit(const Task &task, /// TODO(rkn): Should we check that the node manager is remote and not local? /// TODO(rkn): Should we check if the remote node manager is known to be dead? // Attempt to forward the task. 
- ForwardTask(task, node_manager_id, [this, node_manager_id](ray::Status error, - const Task &task) { - const TaskID task_id = task.GetTaskSpecification().TaskId(); - RAY_LOG(INFO) << "Failed to forward task " << task_id << " to node manager " - << node_manager_id; - - // Mark the failed task as pending to let other raylets know that we still - // have the task. TaskDependencyManager::TaskPending() is assumed to be - // idempotent. - task_dependency_manager_.TaskPending(task); - - // Actor tasks can only be executed at the actor's location, so they are - // retried after a timeout. All other tasks that fail to be forwarded are - // deemed to be placeable again. - if (task.GetTaskSpecification().IsActorTask()) { - // The task is for an actor on another node. Create a timer to resubmit - // the task in a little bit. TODO(rkn): Really this should be a - // unique_ptr instead of a shared_ptr. However, it's a little harder to - // move unique_ptrs into lambdas. - auto retry_timer = std::make_shared(io_service_); - auto retry_duration = boost::posix_time::milliseconds( - RayConfig::instance().node_manager_forward_task_retry_timeout_milliseconds()); - retry_timer->expires_from_now(retry_duration); - retry_timer->async_wait( - [this, task_id, retry_timer](const boost::system::error_code &error) { - // Timer killing will receive the boost::asio::error::operation_aborted, - // we only handle the timeout event. - RAY_CHECK(!error); - RAY_LOG(INFO) << "Resubmitting task " << task_id - << " because ForwardTask failed."; - // Remove the RESUBMITTED task from the SWAP queue. - TaskState state; - const auto task = local_queues_.RemoveTask(task_id, &state); - RAY_CHECK(state == TaskState::SWAP); - // Submit the task again. - SubmitTask(task, Lineage()); - }); - // Temporarily move the RESUBMITTED task to the SWAP queue while the - // timer is active. - local_queues_.QueueTasks({task}, TaskState::SWAP); - // Remove the task from the lineage cache. 
The task will get added back - // once it is resubmitted. - lineage_cache_.RemoveWaitingTask(task_id); - } else { - // The task is not for an actor and may therefore be placed on another - // node immediately. Send it to the scheduling policy to be placed again. - local_queues_.QueueTasks({task}, TaskState::PLACEABLE); - ScheduleTasks(cluster_resource_map_); - } - }); + ForwardTask( + task, node_manager_id, + [this, node_manager_id](ray::Status error, const Task &task) { + const TaskID task_id = task.GetTaskSpecification().TaskId(); + RAY_LOG(INFO) << "Failed to forward task " << task_id << " to node manager " + << node_manager_id; + + // Mark the failed task as pending to let other raylets know that we still + // have the task. TaskDependencyManager::TaskPending() is assumed to be + // idempotent. + task_dependency_manager_.TaskPending(task); + + // Actor tasks can only be executed at the actor's location, so they are + // retried after a timeout. All other tasks that fail to be forwarded are + // deemed to be placeable again. + if (task.GetTaskSpecification().IsActorTask()) { + // The task is for an actor on another node. Create a timer to resubmit + // the task in a little bit. TODO(rkn): Really this should be a + // unique_ptr instead of a shared_ptr. However, it's a little harder to + // move unique_ptrs into lambdas. + auto retry_timer = std::make_shared(io_service_); + auto retry_duration = boost::posix_time::milliseconds( + RayConfig::instance() + .node_manager_forward_task_retry_timeout_milliseconds()); + retry_timer->expires_from_now(retry_duration); + retry_timer->async_wait( + [this, task_id, retry_timer](const boost::system::error_code &error) { + // Timer killing will receive the boost::asio::error::operation_aborted, + // we only handle the timeout event. + RAY_CHECK(!error); + RAY_LOG(INFO) << "Resubmitting task " << task_id + << " because ForwardTask failed."; + // Remove the RESUBMITTED task from the SWAP queue. 
+ TaskState state; + const auto task = local_queues_.RemoveTask(task_id, &state); + RAY_CHECK(state == TaskState::SWAP); + // Submit the task again. + SubmitTask(task, Lineage()); + }); + // Temporarily move the RESUBMITTED task to the SWAP queue while the + // timer is active. + local_queues_.QueueTasks({task}, TaskState::SWAP); + } else { + // The task is not for an actor and may therefore be placed on another + // node immediately. Send it to the scheduling policy to be placed again. + local_queues_.QueueTasks({task}, TaskState::PLACEABLE); + ScheduleTasks(cluster_resource_map_); + } + }); } void NodeManager::ForwardTask( const Task &task, const ClientID &node_id, const std::function &on_error) { + // Lookup node manager client for this node_id and use it to send the request. + auto client_entry = remote_node_manager_clients_.find(node_id); + if (client_entry == remote_node_manager_clients_.end()) { + // TODO(atumanov): caller must handle failure to ensure tasks are not lost. + RAY_LOG(INFO) << "No node manager client found for GCS client id " << node_id; + on_error(ray::Status::IOError("Node manager client not found"), task); + return; + } + auto &client = client_entry->second; + const auto &spec = task.GetTaskSpecification(); auto task_id = spec.TaskId(); @@ -2294,81 +2221,67 @@ void NodeManager::ForwardTask( // Increment forward count for the forwarded task. lineage_cache_entry_task.IncrementNumForwards(); - flatbuffers::FlatBufferBuilder fbb; - auto request = uncommitted_lineage.ToFlatbuffer(fbb, task_id); - fbb.Finish(request); - RAY_LOG(DEBUG) << "Forwarding task " << task_id << " from " << gcs_client_->client_table().GetLocalClientId() << " to " << node_id << " spillback=" << lineage_cache_entry_task.GetTaskExecutionSpec().NumForwards(); - // Lookup remote server connection for this node_id and use it to send the request. 
- auto it = remote_server_connections_.find(node_id); - if (it == remote_server_connections_.end()) { - // TODO(atumanov): caller must handle failure to ensure tasks are not lost. - RAY_LOG(INFO) << "No NodeManager connection found for GCS client id " << node_id; - on_error(ray::Status::IOError("NodeManager connection not found"), task); - return; + // Prepare the request message. + rpc::ForwardTaskRequest request; + request.set_task_id(task_id.Binary()); + for (auto &entry : uncommitted_lineage.GetEntries()) { + request.add_uncommitted_tasks(entry.second.TaskData().Serialize()); } - auto &server_conn = it->second; + // Move the FORWARDING task to the SWAP queue so that we remember that we // have it queued locally. Once the ForwardTaskRequest has been sent, the // task will get re-queued, depending on whether the message succeeded or // not. local_queues_.QueueTasks({task}, TaskState::SWAP); - server_conn->WriteMessageAsync( - static_cast(protocol::MessageType::ForwardTaskRequest), fbb.GetSize(), - fbb.GetBufferPointer(), [this, on_error, task_id, node_id](ray::Status status) { - // Remove the FORWARDING task from the SWAP queue. - TaskState state; - const auto task = local_queues_.RemoveTask(task_id, &state); - RAY_CHECK(state == TaskState::SWAP); - - if (status.ok()) { - const auto &spec = task.GetTaskSpecification(); - // If we were able to forward the task, remove the forwarded task from the - // lineage cache since the receiving node is now responsible for writing - // the task to the GCS. - if (!lineage_cache_.RemoveWaitingTask(task_id)) { - RAY_LOG(WARNING) << "Task " << task_id << " already removed from the lineage" - << " cache. This is most likely due to reconstruction."; - } else { - // Mark as forwarded so that the task and its lineage is not - // re-forwarded in the future to the receiving node. 
- lineage_cache_.MarkTaskAsForwarded(task_id, node_id); - } - - // Notify the task dependency manager that we are no longer responsible - // for executing this task. - task_dependency_manager_.TaskCanceled(task_id); - // Preemptively push any local arguments to the receiving node. For now, we - // only do this with actor tasks, since actor tasks must be executed by a - // specific process and therefore have affinity to the receiving node. - if (spec.IsActorTask()) { - // Iterate through the object's arguments. NOTE(swang): We do not include - // the execution dependencies here since those cannot be transferred - // between nodes. - for (int i = 0; i < spec.NumArgs(); ++i) { - int count = spec.ArgIdCount(i); - for (int j = 0; j < count; j++) { - ObjectID argument_id = spec.ArgId(i, j); - // If the argument is local, then push it to the receiving node. - if (task_dependency_manager_.CheckObjectLocal(argument_id)) { - object_manager_.Push(argument_id, node_id); - } - } + client->ForwardTask(request, [this, on_error, task_id, node_id]( + Status status, const rpc::ForwardTaskReply &reply) { + // Remove the FORWARDING task from the SWAP queue. + TaskState state; + const auto task = local_queues_.RemoveTask(task_id, &state); + RAY_CHECK(state == TaskState::SWAP); + + if (status.ok()) { + const auto &spec = task.GetTaskSpecification(); + // Mark as forwarded so that the task and its lineage are not + // re-forwarded in the future to the receiving node. + lineage_cache_.MarkTaskAsForwarded(task_id, node_id); + + // Notify the task dependency manager that we are no longer responsible + // for executing this task. + task_dependency_manager_.TaskCanceled(task_id); + // Preemptively push any local arguments to the receiving node. For now, we + // only do this with actor tasks, since actor tasks must be executed by a + // specific process and therefore have affinity to the receiving node. + if (spec.IsActorTask()) { + // Iterate through the object's arguments. 
NOTE(swang): We do not include + // the execution dependencies here since those cannot be transferred + // between nodes. + for (int i = 0; i < spec.NumArgs(); ++i) { + int count = spec.ArgIdCount(i); + for (int j = 0; j < count; j++) { + ObjectID argument_id = spec.ArgId(i, j); + // If the argument is local, then push it to the receiving node. + if (task_dependency_manager_.CheckObjectLocal(argument_id)) { + object_manager_.Push(argument_id, node_id); } } - } else { - on_error(status, task); } - }); + } + } else { + on_error(status, task); + } + }); } void NodeManager::DumpDebugState() const { std::fstream fs; - fs.open(temp_dir_ + "/debug_state.txt", std::fstream::out | std::fstream::trunc); + fs.open(initial_config_.session_dir + "/debug_state.txt", + std::fstream::out | std::fstream::trunc); fs << DebugString(); fs.close(); } @@ -2397,10 +2310,11 @@ std::string NodeManager::DebugString() const { result << "\n- num dead actors: " << statistical_data.dead_actors; result << "\n- max num handles: " << statistical_data.max_num_handles; - result << "\nRemoteConnections:"; - for (auto &pair : remote_server_connections_) { - result << "\n" << pair.first.Hex() << ": " << pair.second->DebugString(); + result << "\nRemote node manager clients: "; + for (const auto &entry : remote_node_manager_clients_) { + result << "\n" << entry.first; } + result << "\nDebugString() time ms: " << (current_time_ms() - now_ms); return result.str(); } diff --git a/src/ray/raylet/node_manager.h b/src/ray/raylet/node_manager.h index 576ffbc23f72..61613358330c 100644 --- a/src/ray/raylet/node_manager.h +++ b/src/ray/raylet/node_manager.h @@ -4,6 +4,9 @@ #include // clang-format off +#include "ray/rpc/client_call.h" +#include "ray/rpc/node_manager_server.h" +#include "ray/rpc/node_manager_client.h" #include "ray/raylet/task.h" #include "ray/object_manager/object_manager.h" #include "ray/common/client_connection.h" @@ -48,9 +51,11 @@ struct NodeManagerConfig { std::string store_socket_name; /// 
The path to the ray temp dir. std::string temp_dir; + /// The path of this ray session dir. + std::string session_dir; }; -class NodeManager { +class NodeManager : public rpc::NodeManagerServiceHandler { public: /// Create a node manager. /// @@ -84,15 +89,6 @@ class NodeManager { /// \return Void. void ProcessNewNodeManager(TcpClientConnection &node_manager_client); - /// Handle a message from a remote node manager. - /// - /// \param node_manager_client The connection to the remote node manager. - /// \param message_type The type of the message. - /// \param message The message contents. - /// \return Void. - void ProcessNodeManagerMessage(TcpClientConnection &node_manager_client, - int64_t message_type, const uint8_t *message); - /// Subscribe to the relevant GCS tables and set up handlers. /// /// \return Status indicating whether this was done successfully or not. @@ -106,6 +102,9 @@ class NodeManager { /// Record metrics. void RecordMetrics() const; + /// Get the port of the node manager rpc server. + int GetServerPort() const { return node_manager_server_.GetPort(); } + private: /// Methods for handling clients. @@ -448,15 +447,10 @@ class NodeManager { void HandleDisconnectedActor(const ActorID &actor_id, bool was_local, bool intentional_disconnect); - /// connect to a remote node manager. - /// - /// \param client_id The client ID for the remote node manager. - /// \param client_address The IP address for the remote node manager. - /// \param client_port The listening port for the remote node manager. - /// \return True if the connect succeeds. - ray::Status ConnectRemoteNodeManager(const ClientID &client_id, - const std::string &client_address, - int32_t client_port); + /// Handle a `ForwardTask` request. + void HandleForwardTask(const rpc::ForwardTaskRequest &request, + rpc::ForwardTaskReply *reply, + rpc::RequestDoneCallback done_callback) override; // GCS client ID for this node. 
ClientID client_id_; @@ -503,9 +497,6 @@ class NodeManager { TaskDependencyManager task_dependency_manager_; /// The lineage cache for the GCS object and task tables. LineageCache lineage_cache_; - std::vector remote_clients_; - std::unordered_map> - remote_server_connections_; /// A mapping from actor ID to registration information about that actor /// (including which node manager owns it). std::unordered_map actor_registry_; @@ -513,6 +504,16 @@ class NodeManager { /// This map stores actor ID to the ID of the checkpoint that will be used to /// restore the actor. std::unordered_map checkpoint_id_to_restore_; + + /// The RPC server. + rpc::NodeManagerServer node_manager_server_; + + /// The `ClientCallManager` object that is shared by all `NodeManagerClient`s. + rpc::ClientCallManager client_call_manager_; + + /// Map from node ids to clients of the remote node managers. + std::unordered_map> + remote_node_manager_clients_; }; } // namespace raylet diff --git a/src/ray/raylet/object_manager_integration_test.cc b/src/ray/raylet/object_manager_integration_test.cc index 1b043ca58c2b..0f411e8c581d 100644 --- a/src/ray/raylet/object_manager_integration_test.cc +++ b/src/ray/raylet/object_manager_integration_test.cc @@ -136,16 +136,17 @@ class TestObjectManagerIntegration : public TestObjectManagerBase { void WaitConnections() { client_id_1 = gcs_client_1->client_table().GetLocalClientId(); client_id_2 = gcs_client_2->client_table().GetLocalClientId(); - gcs_client_1->client_table().RegisterClientAddedCallback([this]( - gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableDataT &data) { - ClientID parsed_id = ClientID::FromBinary(data.client_id); - if (parsed_id == client_id_1 || parsed_id == client_id_2) { - num_connected_clients += 1; - } - if (num_connected_clients == 2) { - StartTests(); - } - }); + gcs_client_1->client_table().RegisterClientAddedCallback( + [this](gcs::AsyncGcsClient *client, const ClientID &id, + const ClientTableDataT &data) { + 
ClientID parsed_id = ClientID::FromBinary(data.client_id); + if (parsed_id == client_id_1 || parsed_id == client_id_2) { + num_connected_clients += 1; + } + if (num_connected_clients == 2) { + StartTests(); + } + }); } void StartTests() { diff --git a/src/ray/raylet/raylet.cc b/src/ray/raylet/raylet.cc index dd9e5fac318e..473e6c263ffe 100644 --- a/src/ray/raylet/raylet.cc +++ b/src/ray/raylet/raylet.cc @@ -61,15 +61,10 @@ Raylet::Raylet(boost::asio::io_service &main_service, const std::string &socket_ main_service, boost::asio::ip::tcp::endpoint(boost::asio::ip::tcp::v4(), object_manager_config.object_manager_port)), - object_manager_socket_(main_service), - node_manager_acceptor_(main_service, boost::asio::ip::tcp::endpoint( - boost::asio::ip::tcp::v4(), - node_manager_config.node_manager_port)), - node_manager_socket_(main_service) { + object_manager_socket_(main_service) { // Start listening for clients. DoAccept(); DoAcceptObjectManager(); - DoAcceptNodeManager(); RAY_CHECK_OK(RegisterGcs( node_ip_address, socket_name_, object_manager_config.store_socket_name, @@ -100,7 +95,7 @@ ray::Status Raylet::RegisterGcs(const std::string &node_ip_address, client_info.raylet_socket_name = raylet_socket_name; client_info.object_store_socket_name = object_store_socket_name; client_info.object_manager_port = object_manager_acceptor_.local_endpoint().port(); - client_info.node_manager_port = node_manager_acceptor_.local_endpoint().port(); + client_info.node_manager_port = node_manager_.GetServerPort(); // Add resource information. 
for (const auto &resource_pair : node_manager_config.resource_config.GetResourceMap()) { client_info.resources_total_label.push_back(resource_pair.first); @@ -120,31 +115,6 @@ ray::Status Raylet::RegisterGcs(const std::string &node_ip_address, return Status::OK(); } -void Raylet::DoAcceptNodeManager() { - node_manager_acceptor_.async_accept(node_manager_socket_, - boost::bind(&Raylet::HandleAcceptNodeManager, this, - boost::asio::placeholders::error)); -} - -void Raylet::HandleAcceptNodeManager(const boost::system::error_code &error) { - if (!error) { - ClientHandler client_handler = [this]( - TcpClientConnection &client) { node_manager_.ProcessNewNodeManager(client); }; - MessageHandler message_handler = [this]( - std::shared_ptr client, int64_t message_type, - const uint8_t *message) { - node_manager_.ProcessNodeManagerMessage(*client, message_type, message); - }; - // Accept a new TCP client and dispatch it to the node manager. - auto new_connection = TcpClientConnection::Create( - client_handler, message_handler, std::move(node_manager_socket_), "node manager", - node_manager_message_enum, - static_cast(protocol::MessageType::DisconnectClient)); - } - // We're ready to accept another client. 
- DoAcceptNodeManager(); -} - void Raylet::DoAcceptObjectManager() { object_manager_acceptor_.async_accept( object_manager_socket_, boost::bind(&Raylet::HandleAcceptObjectManager, this, @@ -154,11 +124,11 @@ void Raylet::DoAcceptObjectManager() { void Raylet::HandleAcceptObjectManager(const boost::system::error_code &error) { ClientHandler client_handler = [this](TcpClientConnection &client) { object_manager_.ProcessNewClient(client); }; - MessageHandler message_handler = [this]( - std::shared_ptr client, int64_t message_type, - const uint8_t *message) { - object_manager_.ProcessClientMessage(client, message_type, message); - }; + MessageHandler message_handler = + [this](std::shared_ptr client, int64_t message_type, + const uint8_t *message) { + object_manager_.ProcessClientMessage(client, message_type, message); + }; // Accept a new TCP client and dispatch it to the node manager. auto new_connection = TcpClientConnection::Create( client_handler, message_handler, std::move(object_manager_socket_), @@ -177,11 +147,11 @@ void Raylet::HandleAccept(const boost::system::error_code &error) { // TODO: typedef these handlers. ClientHandler client_handler = [this](LocalClientConnection &client) { node_manager_.ProcessNewClient(client); }; - MessageHandler message_handler = [this]( - std::shared_ptr client, int64_t message_type, - const uint8_t *message) { - node_manager_.ProcessClientMessage(client, message_type, message); - }; + MessageHandler message_handler = + [this](std::shared_ptr client, int64_t message_type, + const uint8_t *message) { + node_manager_.ProcessClientMessage(client, message_type, message); + }; // Accept a new local client and dispatch it to the node manager. 
auto new_connection = LocalClientConnection::Create( client_handler, message_handler, std::move(socket_), "worker", diff --git a/src/ray/raylet/raylet.h b/src/ray/raylet/raylet.h index 84274ea6ecfe..26fe74b2b622 100644 --- a/src/ray/raylet/raylet.h +++ b/src/ray/raylet/raylet.h @@ -63,8 +63,6 @@ class Raylet { void DoAcceptObjectManager(); /// Handle an accepted tcp client connection. void HandleAcceptObjectManager(const boost::system::error_code &error); - void DoAcceptNodeManager(); - void HandleAcceptNodeManager(const boost::system::error_code &error); friend class TestObjectManagerIntegration; @@ -88,10 +86,6 @@ class Raylet { boost::asio::ip::tcp::acceptor object_manager_acceptor_; /// The socket to listen on for new object manager tcp clients. boost::asio::ip::tcp::socket object_manager_socket_; - /// An acceptor for new tcp clients. - boost::asio::ip::tcp::acceptor node_manager_acceptor_; - /// The socket to listen on for new tcp clients. - boost::asio::ip::tcp::socket node_manager_socket_; }; } // namespace raylet diff --git a/src/ray/raylet/scheduling_queue.cc b/src/ray/raylet/scheduling_queue.cc index 85295e403769..73f0e2ef803a 100644 --- a/src/ray/raylet/scheduling_queue.cc +++ b/src/ray/raylet/scheduling_queue.cc @@ -233,8 +233,13 @@ std::vector SchedulingQueue::RemoveTasks(std::unordered_set &task_ std::vector removed_tasks; // Try to find the tasks to remove from the queues. 
for (const auto &task_state : { - TaskState::PLACEABLE, TaskState::WAITING, TaskState::READY, TaskState::RUNNING, - TaskState::INFEASIBLE, TaskState::WAITING_FOR_ACTOR_CREATION, TaskState::SWAP, + TaskState::PLACEABLE, + TaskState::WAITING, + TaskState::READY, + TaskState::RUNNING, + TaskState::INFEASIBLE, + TaskState::WAITING_FOR_ACTOR_CREATION, + TaskState::SWAP, }) { RemoveTasksFromQueue(task_state, task_ids, &removed_tasks); } @@ -248,8 +253,13 @@ Task SchedulingQueue::RemoveTask(const TaskID &task_id, TaskState *removed_task_ std::unordered_set task_id_set = {task_id}; // Try to find the task to remove in the queues. for (const auto &task_state : { - TaskState::PLACEABLE, TaskState::WAITING, TaskState::READY, TaskState::RUNNING, - TaskState::INFEASIBLE, TaskState::WAITING_FOR_ACTOR_CREATION, TaskState::SWAP, + TaskState::PLACEABLE, + TaskState::WAITING, + TaskState::READY, + TaskState::RUNNING, + TaskState::INFEASIBLE, + TaskState::WAITING_FOR_ACTOR_CREATION, + TaskState::SWAP, }) { RemoveTasksFromQueue(task_state, task_id_set, &removed_tasks); if (task_id_set.empty()) { diff --git a/src/ray/raylet/scheduling_queue.h b/src/ray/raylet/scheduling_queue.h index 4fd07e5ca606..465f2a4341a0 100644 --- a/src/ray/raylet/scheduling_queue.h +++ b/src/ray/raylet/scheduling_queue.h @@ -149,9 +149,13 @@ class SchedulingQueue { /// Create a scheduling queue. 
SchedulingQueue() : ready_queue_(std::make_shared()) { for (const auto &task_state : { - TaskState::PLACEABLE, TaskState::WAITING, TaskState::READY, - TaskState::RUNNING, TaskState::INFEASIBLE, - TaskState::WAITING_FOR_ACTOR_CREATION, TaskState::SWAP, + TaskState::PLACEABLE, + TaskState::WAITING, + TaskState::READY, + TaskState::RUNNING, + TaskState::INFEASIBLE, + TaskState::WAITING_FOR_ACTOR_CREATION, + TaskState::SWAP, }) { if (task_state == TaskState::READY) { task_queues_[static_cast(task_state)] = ready_queue_; diff --git a/src/ray/raylet/scheduling_resources.cc b/src/ray/raylet/scheduling_resources.cc index 895535a9a7f0..c80282601256 100644 --- a/src/ray/raylet/scheduling_resources.cc +++ b/src/ray/raylet/scheduling_resources.cc @@ -15,10 +15,11 @@ FractionalResourceQuantity::FractionalResourceQuantity(double resource_quantity) // We check for nonnegativeity due to the implicit conversion to // FractionalResourceQuantity from ints/doubles when we do logical // comparisons. - RAY_CHECK(resource_quantity >= 0) << "Resource capacity, " << resource_quantity - << ", should be nonnegative."; + RAY_CHECK(resource_quantity >= 0) + << "Resource capacity, " << resource_quantity << ", should be nonnegative."; - resource_quantity_ = static_cast(resource_quantity * kResourceConversionFactor); + resource_quantity_ = + static_cast(resource_quantity * kResourceConversionFactor); } const FractionalResourceQuantity FractionalResourceQuantity::operator+( @@ -76,7 +77,11 @@ ResourceSet::ResourceSet() {} ResourceSet::ResourceSet( const std::unordered_map &resource_map) - : resource_capacity_(resource_map) {} + : resource_capacity_(resource_map) { + for (auto const &resource_pair : resource_map) { + RAY_CHECK(resource_pair.second > 0); + } +} ResourceSet::ResourceSet(const std::unordered_map &resource_map) { for (auto const &resource_pair : resource_map) { @@ -169,7 +174,8 @@ void ResourceSet::SubtractResourcesStrict(const ResourceSet &other) { const std::string &resource_label = 
resource_pair.first; const FractionalResourceQuantity &resource_capacity = resource_pair.second; RAY_CHECK(resource_capacity_.count(resource_label) == 1) - << "Attempt to acquire unknown resource: " << resource_label; + << "Attempt to acquire unknown resource: " << resource_label << " capacity " + << resource_capacity.ToDouble(); resource_capacity_[resource_label] -= resource_capacity; // Ensure that quantity is positive. Note, we have to have the check before @@ -233,8 +239,10 @@ FractionalResourceQuantity ResourceSet::GetResource( const ResourceSet ResourceSet::GetNumCpus() const { ResourceSet cpu_resource_set; - cpu_resource_set.resource_capacity_[kCPU_ResourceLabel] = - GetResource(kCPU_ResourceLabel); + const FractionalResourceQuantity cpu_quantity = GetResource(kCPU_ResourceLabel); + if (cpu_quantity > 0) { + cpu_resource_set.resource_capacity_[kCPU_ResourceLabel] = cpu_quantity; + } return cpu_resource_set; } diff --git a/src/ray/raylet/scheduling_resources.h b/src/ray/raylet/scheduling_resources.h index 9f64ddae6b45..9e3a2a64ce4a 100644 --- a/src/ray/raylet/scheduling_resources.h +++ b/src/ray/raylet/scheduling_resources.h @@ -58,7 +58,7 @@ class FractionalResourceQuantity { private: /// The resource quantity represented as 1/kResourceConversionFactor-th of a /// unit. 
- int resource_quantity_; + int64_t resource_quantity_; }; /// \class ResourceSet diff --git a/src/ray/raylet/task.cc b/src/ray/raylet/task.cc index 5d6a02186ced..9d8036411303 100644 --- a/src/ray/raylet/task.cc +++ b/src/ray/raylet/task.cc @@ -46,14 +46,18 @@ void Task::CopyTaskExecutionSpec(const Task &task) { ComputeDependencies(); } +const std::string Task::Serialize() const { + flatbuffers::FlatBufferBuilder fbb; + fbb.Finish(ToFlatbuffer(fbb)); + return std::string(fbb.GetBufferPointer(), fbb.GetBufferPointer() + fbb.GetSize()); +} + std::string SerializeTaskAsString(const std::vector *dependencies, const TaskSpecification *task_spec) { - flatbuffers::FlatBufferBuilder fbb; std::vector execution_dependencies(*dependencies); TaskExecutionSpecification execution_spec(std::move(execution_dependencies)); Task task(execution_spec, *task_spec); - fbb.Finish(task.ToFlatbuffer(fbb)); - return std::string(fbb.GetBufferPointer(), fbb.GetBufferPointer() + fbb.GetSize()); + return task.Serialize(); } } // namespace raylet diff --git a/src/ray/raylet/task.h b/src/ray/raylet/task.h index b942e2bf2c03..10cdfe5110f4 100644 --- a/src/ray/raylet/task.h +++ b/src/ray/raylet/task.h @@ -84,6 +84,9 @@ class Task { /// \param task Task structure with updated dynamic information. void CopyTaskExecutionSpec(const Task &task); + /// Serialize this task as a string. 
+ const std::string Serialize() const; + private: void ComputeDependencies(); diff --git a/src/ray/raylet/worker_pool.cc b/src/ray/raylet/worker_pool.cc index 43698c53f0d8..d4ac4cf4ecce 100644 --- a/src/ray/raylet/worker_pool.cc +++ b/src/ray/raylet/worker_pool.cc @@ -164,7 +164,10 @@ void WorkerPool::RegisterWorker(const std::shared_ptr &worker) { state.registered_workers.insert(std::move(worker)); auto it = state.starting_worker_processes.find(pid); - RAY_CHECK(it != state.starting_worker_processes.end()); + if (it == state.starting_worker_processes.end()) { + RAY_LOG(WARNING) << "Received a register request from an unknown worker " << pid; + return; + } it->second--; if (it->second == 0) { state.starting_worker_processes.erase(it); } diff --git a/src/ray/rpc/client_call.h b/src/ray/rpc/client_call.h new file mode 100644 index 000000000000..725652cb5ebc --- /dev/null +++ b/src/ray/rpc/client_call.h @@ -0,0 +1,169 @@ +#ifndef RAY_RPC_CLIENT_CALL_H +#define RAY_RPC_CLIENT_CALL_H + +#include +#include + +#include "ray/common/status.h" +#include "ray/rpc/util.h" + +namespace ray { +namespace rpc { + +/// Represents an outgoing gRPC request. +/// +/// The lifecycle of a `ClientCall` is as follows. +/// +/// When a client submits a new gRPC request, a new `ClientCall` object will be created +/// by `ClientCallManager::CreateCall`. Then the object will be used as the tag of +/// `CompletionQueue`. +/// +/// When the reply is received, `ClientCallManager` will get the address of this object +/// via `CompletionQueue`'s tag. And the manager should call `OnReplyReceived` and then +/// delete this object. +/// +/// NOTE(hchen): Compared to `ClientCallImpl`, this abstract interface doesn't use +/// template. This allows the users (e.g., `ClientCallManager`) not having to use +/// template as well. +class ClientCall { + public: + /// The callback to be called by `ClientCallManager` when the reply of this request is + /// received. 
+ virtual void OnReplyReceived() = 0; +}; + +class ClientCallManager; + +/// Represents the client callback function of a particular rpc method. +/// +/// \tparam Reply Type of the reply message. +template +using ClientCallback = std::function; + +/// Implementation of the `ClientCall`. It represents a `ClientCall` for a particular +/// RPC method. +/// +/// \tparam Reply Type of the Reply message. +template +class ClientCallImpl : public ClientCall { + public: + void OnReplyReceived() override { callback_(GrpcStatusToRayStatus(status_), reply_); } + + private: + /// Constructor. + /// + /// \param[in] callback The callback function to handle the reply. + ClientCallImpl(const ClientCallback &callback) : callback_(callback) {} + + /// The reply message. + Reply reply_; + + /// The callback function to handle the reply. + ClientCallback callback_; + + /// The response reader. + std::unique_ptr> response_reader_; + + /// gRPC status of this request. + grpc::Status status_; + + /// Context for the client. It could be used to convey extra information to + /// the server and/or tweak certain RPC behaviors. + grpc::ClientContext context_; + + friend class ClientCallManager; +}; + +/// Represents the generic signature of a `FooService::Stub::PrepareAsyncBar` +/// function, where `Foo` is the service name and `Bar` is the rpc method name. +/// +/// \tparam GrpcService Type of the gRPC-generated service class. +/// \tparam Request Type of the request message. +/// \tparam Reply Type of the reply message. +template +using PrepareAsyncFunction = std::unique_ptr> ( + GrpcService::Stub::*)(grpc::ClientContext *context, const Request &request, + grpc::CompletionQueue *cq); + +/// `ClientCallManager` is used to manage outgoing gRPC requests and the lifecycles of +/// `ClientCall` objects. +/// +/// It maintains a thread that keeps polling events from `CompletionQueue`, and posts +/// the callback function to the main event loop when a reply is received. 
+/// +/// Multiple clients can share one `ClientCallManager`. +class ClientCallManager { + public: + /// Constructor. + /// + /// \param[in] main_service The main event loop, to which the callback functions will be + /// posted. + ClientCallManager(boost::asio::io_service &main_service) : main_service_(main_service) { + // Start the polling thread. + std::thread polling_thread(&ClientCallManager::PollEventsFromCompletionQueue, this); + polling_thread.detach(); + } + + ~ClientCallManager() { cq_.Shutdown(); } + + /// Create a new `ClientCall` and send request. + /// + /// \param[in] stub The gRPC-generated stub. + /// \param[in] prepare_async_function Pointer to the gRPC-generated + /// `FooService::Stub::PrepareAsyncBar` function. + /// \param[in] request The request message. + /// \param[in] callback The callback function that handles reply. + /// + /// \tparam GrpcService Type of the gRPC-generated service class. + /// \tparam Request Type of the request message. + /// \tparam Reply Type of the reply message. + template + ClientCall *CreateCall( + typename GrpcService::Stub &stub, + const PrepareAsyncFunction prepare_async_function, + const Request &request, const ClientCallback &callback) { + // Create a new `ClientCall` object. This object will eventually be deleted in the + // `ClientCallManager::PollEventsFromCompletionQueue` when reply is received. + auto call = new ClientCallImpl(callback); + // Send request. + call->response_reader_ = + (stub.*prepare_async_function)(&call->context_, request, &cq_); + call->response_reader_->StartCall(); + call->response_reader_->Finish(&call->reply_, &call->status_, (void *)call); + return call; + } + + private: + /// This function runs in a background thread. It keeps polling events from the + /// `CompletionQueue`, and dispatches the event to the callbacks via the `ClientCall` + /// objects. 
+ void PollEventsFromCompletionQueue() { + void *got_tag; + bool ok = false; + // Keep reading events from the `CompletionQueue` until it's shutdown. + while (cq_.Next(&got_tag, &ok)) { + ClientCall *call = reinterpret_cast(got_tag); + if (ok) { + // Post the callback to the main event loop. + main_service_.post([call]() { + call->OnReplyReceived(); + // The call is finished, we can delete the `ClientCall` object now. + delete call; + }); + } else { + delete call; + } + } + } + + /// The main event loop, to which the callback functions will be posted. + boost::asio::io_service &main_service_; + + /// The gRPC `CompletionQueue` object used to poll events. + grpc::CompletionQueue cq_; +}; + +} // namespace rpc +} // namespace ray + +#endif diff --git a/src/ray/rpc/grpc_server.cc b/src/ray/rpc/grpc_server.cc new file mode 100644 index 000000000000..feb788da7692 --- /dev/null +++ b/src/ray/rpc/grpc_server.cc @@ -0,0 +1,70 @@ +#include "ray/rpc/grpc_server.h" + +namespace ray { +namespace rpc { + +void GrpcServer::Run() { + std::string server_address("0.0.0.0:" + std::to_string(port_)); + + grpc::ServerBuilder builder; + // TODO(hchen): Add options for authentication. + builder.AddListeningPort(server_address, grpc::InsecureServerCredentials(), &port_); + // Allow subclasses to register concrete services. + RegisterServices(builder); + // Get hold of the completion queue used for the asynchronous communication + // with the gRPC runtime. + cq_ = builder.AddCompletionQueue(); + // Build and start server. + server_ = builder.BuildAndStart(); + RAY_LOG(DEBUG) << name_ << " server started, listening on port " << port_ << "."; + + // Allow subclasses to initialize the server call factories. + InitServerCallFactories(&server_call_factories_and_concurrencies_); + for (auto &entry : server_call_factories_and_concurrencies_) { + for (int i = 0; i < entry.second; i++) { + // Create and request calls from the factory. 
+ entry.first->CreateCall(); + } + } + // Start a thread that polls incoming requests. + std::thread polling_thread(&GrpcServer::PollEventsFromCompletionQueue, this); + polling_thread.detach(); +} + +void GrpcServer::PollEventsFromCompletionQueue() { + void *tag; + bool ok; + // Keep reading events from the `CompletionQueue` until it's shutdown. + while (cq_->Next(&tag, &ok)) { + ServerCall *server_call = static_cast(tag); + // `ok == false` indicates that the server has been shut down. + // We should delete the call object in this case. + bool delete_call = !ok; + if (ok) { + switch (server_call->GetState()) { + case ServerCallState::PENDING: + // We've received a new incoming request. Now this call object is used to + // track this request. So we need to create another call to handle next + // incoming request. + server_call->GetFactory().CreateCall(); + server_call->SetState(ServerCallState::PROCESSING); + main_service_.post([server_call] { server_call->HandleRequest(); }); + break; + case ServerCallState::SENDING_REPLY: + // The reply has been sent, this call can be deleted now. + // This event is triggered by `ServerCallImpl::SendReply`. + delete_call = true; + break; + default: + RAY_LOG(FATAL) << "Shouldn't reach here."; + break; + } + } + if (delete_call) { + delete server_call; + } + } +} + +} // namespace rpc +} // namespace ray diff --git a/src/ray/rpc/grpc_server.h b/src/ray/rpc/grpc_server.h new file mode 100644 index 000000000000..4953f470610f --- /dev/null +++ b/src/ray/rpc/grpc_server.h @@ -0,0 +1,92 @@ +#ifndef RAY_RPC_GRPC_SERVER_H +#define RAY_RPC_GRPC_SERVER_H + +#include + +#include +#include + +#include "ray/common/status.h" +#include "ray/rpc/server_call.h" + +namespace ray { +namespace rpc { + +/// Base class that represents an abstract gRPC server. +/// +/// A `GrpcServer` listens on a specific port. 
It owns +/// 1) a `ServerCompletionQueue` that is used for polling events from gRPC, +/// 2) and a thread that polls events from the `ServerCompletionQueue`. +/// +/// Subclasses can register one or multiple services to a `GrpcServer`, see +/// `RegisterServices`. And they should also implement `InitServerCallFactories` to decide +/// which kinds of requests this server should accept. +class GrpcServer { + public: + /// Constructor. + /// + /// \param[in] name Name of this server, used for logging and debugging purpose. + /// \param[in] port The port to bind this server to. If it's 0, a random available port + /// will be chosen. + /// \param[in] main_service The main event loop, to which service handler functions + /// will be posted. + GrpcServer(const std::string &name, const uint32_t port, + boost::asio::io_service &main_service) + : name_(name), port_(port), main_service_(main_service) {} + + /// Destruct this gRPC server. + ~GrpcServer() { + server_->Shutdown(); + cq_->Shutdown(); + } + + /// Initialize and run this server. + void Run(); + + /// Get the port of this gRPC server. + int GetPort() const { return port_; } + + protected: + /// Subclasses should implement this method and register one or multiple gRPC services + /// to the given `ServerBuilder`. + /// + /// \param[in] builder The `ServerBuilder` instance to register services to. + virtual void RegisterServices(grpc::ServerBuilder &builder) = 0; + + /// Subclasses should implement this method to initialize the `ServerCallFactory` + /// instances, as well as specify maximum number of concurrent requests that gRPC + /// server can "accept" (not "handle"). Each factory will be used to create + /// `accept_concurrency` `ServerCall` objects, each of which will be used to accept and + /// handle an incoming request. + /// + /// \param[out] server_call_factories_and_concurrencies The `ServerCallFactory` objects, + /// and the maximum number of concurrent requests that gRPC server can accept. 
+  virtual void InitServerCallFactories(
+      std::vector<std::pair<std::unique_ptr<ServerCallFactory>, int>>
+          *server_call_factories_and_concurrencies) = 0;
+
+  /// This function runs in a background thread. It keeps polling events from the
+  /// `ServerCompletionQueue`, and dispatches the event to the `ServiceHandler` instances
+  /// via the `ServerCall` objects.
+  void PollEventsFromCompletionQueue();
+
+  /// The main event loop, to which the service handler functions will be posted.
+  boost::asio::io_service &main_service_;
+  /// Name of this server, used for logging and debugging purpose.
+  const std::string name_;
+  /// Port of this server.
+  int port_;
+  /// The `ServerCallFactory` objects, and the maximum number of concurrent requests that
+  /// gRPC server can accept.
+  std::vector<std::pair<std::unique_ptr<ServerCallFactory>, int>>
+      server_call_factories_and_concurrencies_;
+  /// The `ServerCompletionQueue` object used for polling events.
+  std::unique_ptr<grpc::ServerCompletionQueue> cq_;
+  /// The `Server` object.
+  std::unique_ptr<grpc::Server> server_;
+};
+
+}  // namespace rpc
+}  // namespace ray
+
+#endif
diff --git a/src/ray/rpc/node_manager_client.h b/src/ray/rpc/node_manager_client.h
new file mode 100644
index 000000000000..005c75db40d2
--- /dev/null
+++ b/src/ray/rpc/node_manager_client.h
@@ -0,0 +1,56 @@
+#ifndef RAY_RPC_NODE_MANAGER_CLIENT_H
+#define RAY_RPC_NODE_MANAGER_CLIENT_H
+
+#include 
+
+#include 
+
+#include "ray/common/status.h"
+#include "ray/rpc/client_call.h"
+#include "ray/util/logging.h"
+#include "src/ray/protobuf/node_manager.grpc.pb.h"
+#include "src/ray/protobuf/node_manager.pb.h"
+
+namespace ray {
+namespace rpc {
+
+/// Client used for communicating with a remote node manager server.
+class NodeManagerClient {
+ public:
+  /// Constructor.
+  ///
+  /// \param[in] address Address of the node manager server.
+  /// \param[in] port Port of the node manager server.
+  /// \param[in] client_call_manager The `ClientCallManager` used for managing requests.
+  NodeManagerClient(const std::string &address, const int port,
+                    ClientCallManager &client_call_manager)
+      : client_call_manager_(client_call_manager) {
+    std::shared_ptr<grpc::Channel> channel = grpc::CreateChannel(
+        address + ":" + std::to_string(port), grpc::InsecureChannelCredentials());
+    stub_ = NodeManagerService::NewStub(channel);
+  };
+
+  /// Forward a task and its uncommitted lineage.
+  ///
+  /// \param[in] request The request message.
+  /// \param[in] callback The callback function that handles reply.
+  void ForwardTask(const ForwardTaskRequest &request,
+                   const ClientCallback<ForwardTaskReply> &callback) {
+    client_call_manager_
+        .CreateCall<NodeManagerService, ForwardTaskRequest, ForwardTaskReply>(
+            *stub_, &NodeManagerService::Stub::PrepareAsyncForwardTask, request,
+            callback);
+  }
+
+ private:
+  /// The gRPC-generated stub.
+  std::unique_ptr<NodeManagerService::Stub> stub_;
+
+  /// The `ClientCallManager` used for managing requests.
+  ClientCallManager &client_call_manager_;
+};
+
+}  // namespace rpc
+}  // namespace ray
+
+#endif  // RAY_RPC_NODE_MANAGER_CLIENT_H
diff --git a/src/ray/rpc/node_manager_server.h b/src/ray/rpc/node_manager_server.h
new file mode 100644
index 000000000000..afaea299ea89
--- /dev/null
+++ b/src/ray/rpc/node_manager_server.h
@@ -0,0 +1,71 @@
+#ifndef RAY_RPC_NODE_MANAGER_SERVER_H
+#define RAY_RPC_NODE_MANAGER_SERVER_H
+
+#include "ray/rpc/grpc_server.h"
+#include "ray/rpc/server_call.h"
+
+#include "src/ray/protobuf/node_manager.grpc.pb.h"
+#include "src/ray/protobuf/node_manager.pb.h"
+
+namespace ray {
+namespace rpc {
+
+/// Interface of the `NodeManagerService`, see `src/ray/protobuf/node_manager.proto`.
+class NodeManagerServiceHandler {
+ public:
+  /// Handle a `ForwardTask` request.
+  /// The implementation can handle this request asynchronously. When handling is done,
+  /// the `done_callback` should be called.
+  ///
+  /// \param[in] request The request message.
+  /// \param[out] reply The reply message.
+  /// \param[in] done_callback The callback to be called when the request is done.
+ virtual void HandleForwardTask(const ForwardTaskRequest &request, + ForwardTaskReply *reply, + RequestDoneCallback done_callback) = 0; +}; + +/// The `GrpcServer` for `NodeManagerService`. +class NodeManagerServer : public GrpcServer { + public: + /// Constructor. + /// + /// \param[in] port See super class. + /// \param[in] main_service See super class. + /// \param[in] handler The service handler that actually handle the requests. + NodeManagerServer(const uint32_t port, boost::asio::io_service &main_service, + NodeManagerServiceHandler &service_handler) + : GrpcServer("NodeManager", port, main_service), + service_handler_(service_handler){}; + + void RegisterServices(grpc::ServerBuilder &builder) override { + /// Register `NodeManagerService`. + builder.RegisterService(&service_); + } + + void InitServerCallFactories( + std::vector, int>> + *server_call_factories_and_concurrencies) override { + // Initialize the factory for `ForwardTask` requests. + std::unique_ptr forward_task_call_factory( + new ServerCallFactoryImpl( + service_, &NodeManagerService::AsyncService::RequestForwardTask, + service_handler_, &NodeManagerServiceHandler::HandleForwardTask, cq_)); + + // Set `ForwardTask`'s accept concurrency to 100. + server_call_factories_and_concurrencies->emplace_back( + std::move(forward_task_call_factory), 100); + } + + private: + /// The grpc async service object. + NodeManagerService::AsyncService service_; + /// The service handler that actually handle the requests. 
+  NodeManagerServiceHandler &service_handler_;
+};
+
+}  // namespace rpc
+}  // namespace ray
+
+#endif
diff --git a/src/ray/rpc/server_call.h b/src/ray/rpc/server_call.h
new file mode 100644
index 000000000000..e06278260ab6
--- /dev/null
+++ b/src/ray/rpc/server_call.h
@@ -0,0 +1,233 @@
+#ifndef RAY_RPC_SERVER_CALL_H
+#define RAY_RPC_SERVER_CALL_H
+
+#include 
+
+#include "ray/common/status.h"
+#include "ray/rpc/util.h"
+
+namespace ray {
+namespace rpc {
+
+/// Represents the callback function to be called when a `ServiceHandler` finishes
+/// handling a request.
+using RequestDoneCallback = std::function<void(Status)>;
+
+/// Represents state of a `ServerCall`.
+enum class ServerCallState {
+  /// The call is created and waiting for an incoming request.
+  PENDING,
+  /// Request is received and being processed.
+  PROCESSING,
+  /// Request processing is done, and reply is being sent to client.
+  SENDING_REPLY
+};
+
+class ServerCallFactory;
+
+/// Represents an incoming request of a gRPC server.
+///
+/// The lifecycle and state transition of a `ServerCall` is as follows:
+///
+/// --(1)--> PENDING --(2)--> PROCESSING --(3)--> SENDING_REPLY --(4)--> [FINISHED]
+///
+/// (1) The `GrpcServer` creates a `ServerCall` and uses it as the tag to accept requests
+/// from the gRPC `CompletionQueue`. Now the state is `PENDING`.
+/// (2) When a request is received, an event will be gotten from the `CompletionQueue`.
+/// `GrpcServer` then should change `ServerCall`'s state to PROCESSING and call
+/// `ServerCall::HandleRequest`.
+/// (3) When the `ServiceHandler` finishes handling the request, `ServerCallImpl::Finish`
+/// will be called, and the state becomes `SENDING_REPLY`.
+/// (4) When the reply is sent, an event will be gotten from the `CompletionQueue`.
+/// `GrpcServer` will then delete this call.
+///
+/// NOTE(hchen): Compared to `ServerCallImpl`, this abstract interface doesn't use
+/// template. This allows the users (e.g., `GrpcServer`) not having to use
+/// template as well.
+class ServerCall {
+ public:
+  /// Get the state of this `ServerCall`.
+  virtual ServerCallState GetState() const = 0;
+
+  /// Set state of this `ServerCall`.
+  virtual void SetState(const ServerCallState &new_state) = 0;
+
+  /// Handle the request. This is the callback function to be called by
+  /// `GrpcServer` when the request is received.
+  virtual void HandleRequest() = 0;
+
+  /// Get the factory that created this `ServerCall`.
+  virtual const ServerCallFactory &GetFactory() const = 0;
+};
+
+/// The factory that creates a particular kind of `ServerCall` object.
+class ServerCallFactory {
+ public:
+  /// Create a new `ServerCall` and request gRPC runtime to start accepting the
+  /// corresponding type of requests.
+  ///
+  /// \return Pointer to the `ServerCall` object.
+  virtual ServerCall *CreateCall() const = 0;
+};
+
+/// Represents the generic signature of a `FooServiceHandler::HandleBar()`
+/// function, where `Foo` is the service name and `Bar` is the rpc method name.
+///
+/// \tparam ServiceHandler Type of the handler that handles the request.
+/// \tparam Request Type of the request message.
+/// \tparam Reply Type of the reply message.
+template <class ServiceHandler, class Request, class Reply>
+using HandleRequestFunction = void (ServiceHandler::*)(const Request &, Reply *,
+                                                       RequestDoneCallback);
+
+/// Implementation of `ServerCall`. It represents `ServerCall` for a particular
+/// RPC method.
+///
+/// \tparam ServiceHandler Type of the handler that handles the request.
+/// \tparam Request Type of the request message.
+/// \tparam Reply Type of the reply message.
+template <class ServiceHandler, class Request, class Reply>
+class ServerCallImpl : public ServerCall {
+ public:
+  /// Constructor.
+  ///
+  /// \param[in] factory The factory which created this call.
+  /// \param[in] service_handler The service handler that handles the request.
+  /// \param[in] handle_request_function Pointer to the service handler function.
+  ServerCallImpl(
+      const ServerCallFactory &factory, ServiceHandler &service_handler,
+      HandleRequestFunction<ServiceHandler, Request, Reply> handle_request_function)
+      : state_(ServerCallState::PENDING),
+        factory_(factory),
+        service_handler_(service_handler),
+        handle_request_function_(handle_request_function),
+        response_writer_(&context_) {}
+
+  ServerCallState GetState() const override { return state_; }
+
+  void SetState(const ServerCallState &new_state) override { state_ = new_state; }
+
+  void HandleRequest() override {
+    state_ = ServerCallState::PROCESSING;
+    (service_handler_.*handle_request_function_)(request_, &reply_,
+                                                 [this](Status status) {
+                                                   // When the handler is done with the
+                                                   // request, tell gRPC to finish this
+                                                   // request.
+                                                   SendReply(status);
+                                                 });
+  }
+
+  const ServerCallFactory &GetFactory() const override { return factory_; }
+
+ private:
+  /// Tell gRPC to finish this request.
+  void SendReply(Status status) {
+    state_ = ServerCallState::SENDING_REPLY;
+    response_writer_.Finish(reply_, RayStatusToGrpcStatus(status), this);
+  }
+
+  /// State of this call.
+  ServerCallState state_;
+
+  /// The factory which created this call.
+  const ServerCallFactory &factory_;
+
+  /// The service handler that handles the request.
+  ServiceHandler &service_handler_;
+
+  /// Pointer to the service handler function.
+  HandleRequestFunction<ServiceHandler, Request, Reply> handle_request_function_;
+
+  /// Context for the request, allowing to tweak aspects of it such as the use
+  /// of compression, authentication, as well as to send metadata back to the client.
+  grpc::ServerContext context_;
+
+  /// The response writer.
+  grpc::ServerAsyncResponseWriter<Reply> response_writer_;
+
+  /// The request message.
+  Request request_;
+
+  /// The reply message.
+  Reply reply_;
+
+  template <class T1, class T2, class T3, class T4>
+  friend class ServerCallFactoryImpl;
+};
+
+/// Represents the generic signature of a `FooService::AsyncService::RequestBar()`
+/// function, where `Foo` is the service name and `Bar` is the rpc method name.
+/// \tparam GrpcService Type of the gRPC-generated service class.
+/// \tparam Request Type of the request message.
+/// \tparam Reply Type of the reply message.
+template <class GrpcService, class Request, class Reply>
+using RequestCallFunction = void (GrpcService::AsyncService::*)(
+    grpc::ServerContext *, Request *, grpc::ServerAsyncResponseWriter<Reply> *,
+    grpc::CompletionQueue *, grpc::ServerCompletionQueue *, void *);
+
+/// Implementation of `ServerCallFactory`
+///
+/// \tparam GrpcService Type of the gRPC-generated service class.
+/// \tparam ServiceHandler Type of the handler that handles the request.
+/// \tparam Request Type of the request message.
+/// \tparam Reply Type of the reply message.
+template <class GrpcService, class ServiceHandler, class Request, class Reply>
+class ServerCallFactoryImpl : public ServerCallFactory {
+  using AsyncService = typename GrpcService::AsyncService;
+
+ public:
+  /// Constructor.
+  ///
+  /// \param[in] service The gRPC-generated `AsyncService`.
+  /// \param[in] request_call_function Pointer to the `AsyncService::RequestMethod`
+  /// function.
+  /// \param[in] service_handler The service handler that handles the request.
+  /// \param[in] handle_request_function Pointer to the service handler function.
+  /// \param[in] cq The `CompletionQueue`.
+  ServerCallFactoryImpl(
+      AsyncService &service,
+      RequestCallFunction<GrpcService, Request, Reply> request_call_function,
+      ServiceHandler &service_handler,
+      HandleRequestFunction<ServiceHandler, Request, Reply> handle_request_function,
+      const std::unique_ptr<grpc::ServerCompletionQueue> &cq)
+      : service_(service),
+        request_call_function_(request_call_function),
+        service_handler_(service_handler),
+        handle_request_function_(handle_request_function),
+        cq_(cq) {}
+
+  ServerCall *CreateCall() const override {
+    // Create a new `ServerCall`. This object will eventually be deleted by
+    // `GrpcServer::PollEventsFromCompletionQueue`.
+    auto call = new ServerCallImpl<ServiceHandler, Request, Reply>(
+        *this, service_handler_, handle_request_function_);
+    /// Request gRPC runtime to start accepting this kind of request, using the call as
+    /// the tag.
+ (service_.*request_call_function_)(&call->context_, &call->request_, + &call->response_writer_, cq_.get(), cq_.get(), + call); + return call; + } + + private: + /// The gRPC-generated `AsyncService`. + AsyncService &service_; + + /// Pointer to the `AsyncService::RequestMethod` function. + RequestCallFunction request_call_function_; + + /// The service handler that handles the request. + ServiceHandler &service_handler_; + + /// Pointer to the service handler function. + HandleRequestFunction handle_request_function_; + + /// The `CompletionQueue`. + const std::unique_ptr &cq_; +}; + +} // namespace rpc +} // namespace ray + +#endif diff --git a/src/ray/rpc/util.h b/src/ray/rpc/util.h new file mode 100644 index 000000000000..6ecc6c3c4a34 --- /dev/null +++ b/src/ray/rpc/util.h @@ -0,0 +1,33 @@ +#ifndef RAY_RPC_UTIL_H +#define RAY_RPC_UTIL_H + +#include + +#include "ray/common/status.h" + +namespace ray { +namespace rpc { + +/// Helper function that converts a ray status to gRPC status. +inline grpc::Status RayStatusToGrpcStatus(const Status &ray_status) { + if (ray_status.ok()) { + return grpc::Status::OK; + } else { + // TODO(hchen): Use more specific error code. + return grpc::Status(grpc::StatusCode::UNKNOWN, ray_status.message()); + } +} + +/// Helper function that converts a gRPC status to ray status. 
+inline Status GrpcStatusToRayStatus(const grpc::Status &grpc_status) { + if (grpc_status.ok()) { + return Status::OK(); + } else { + return Status::IOError(grpc_status.error_message()); + } +} + +} // namespace rpc +} // namespace ray + +#endif diff --git a/src/ray/test/run_core_worker_tests.sh b/src/ray/test/run_core_worker_tests.sh index 5f1dd2eda69f..7668b92ac272 100644 --- a/src/ray/test/run_core_worker_tests.sh +++ b/src/ray/test/run_core_worker_tests.sh @@ -6,7 +6,7 @@ set -e set -x -bazel build "//:core_worker_test" "//:raylet" "//:libray_redis_module.so" "@plasma//:plasma_store_server" +bazel build "//:core_worker_test" "//:mock_worker" "//:raylet" "//:libray_redis_module.so" "@plasma//:plasma_store_server" # Get the directory in which this script is executing. SCRIPT_DIR="`dirname \"$0\"`" @@ -26,6 +26,7 @@ REDIS_MODULE="./bazel-bin/libray_redis_module.so" LOAD_MODULE_ARGS="--loadmodule ${REDIS_MODULE}" STORE_EXEC="./bazel-bin/external/plasma/plasma_store_server" RAYLET_EXEC="./bazel-bin/raylet" +MOCK_WORKER_EXEC="./bazel-bin/mock_worker" # Allow cleanup commands to fail. bazel run //:redis-cli -- -p 6379 shutdown || true @@ -37,11 +38,8 @@ sleep 2s bazel run //:redis-server -- --loglevel warning ${LOAD_MODULE_ARGS} --port 6380 & sleep 2s # Run tests. -./bazel-bin/core_worker_test $STORE_EXEC $RAYLET_EXEC +./bazel-bin/core_worker_test $STORE_EXEC $RAYLET_EXEC $MOCK_WORKER_EXEC sleep 1s bazel run //:redis-cli -- -p 6379 shutdown bazel run //:redis-cli -- -p 6380 shutdown sleep 1s - -# Include raylet integration test once it's ready. 
-# ./bazel-bin/object_manager_integration_test $STORE_EXEC diff --git a/src/ray/util/logging.cc b/src/ray/util/logging.cc index 97c871d8cb7f..1e2a95408f13 100644 --- a/src/ray/util/logging.cc +++ b/src/ray/util/logging.cc @@ -186,8 +186,7 @@ bool RayLog::IsLevelEnabled(RayLogLevel log_level) { RayLog::RayLog(const char *file_name, int line_number, RayLogLevel severity) // glog does not have DEBUG level, we can handle it using is_enabled_. - : logging_provider_(nullptr), - is_enabled_(severity >= severity_threshold_) { + : logging_provider_(nullptr), is_enabled_(severity >= severity_threshold_) { #ifdef RAY_USE_GLOG if (is_enabled_) { logging_provider_ = diff --git a/src/ray/util/logging.h b/src/ray/util/logging.h index d37ab9a73897..39428eba9583 100644 --- a/src/ray/util/logging.h +++ b/src/ray/util/logging.h @@ -16,19 +16,19 @@ enum class RayLogLevel { DEBUG = -1, INFO = 0, WARNING = 1, ERROR = 2, FATAL = 3 #define RAY_IGNORE_EXPR(expr) ((void)(expr)) -#define RAY_CHECK(condition) \ - (condition) ? RAY_IGNORE_EXPR(0) \ - : ::ray::Voidify() & \ - ::ray::RayLog(__FILE__, __LINE__, ray::RayLogLevel::FATAL) \ - << " Check failed: " #condition " " +#define RAY_CHECK(condition) \ + (condition) \ + ? RAY_IGNORE_EXPR(0) \ + : ::ray::Voidify() & ::ray::RayLog(__FILE__, __LINE__, ray::RayLogLevel::FATAL) \ + << " Check failed: " #condition " " #ifdef NDEBUG -#define RAY_DCHECK(condition) \ - (condition) ? RAY_IGNORE_EXPR(0) \ - : ::ray::Voidify() & \ - ::ray::RayLog(__FILE__, __LINE__, ray::RayLogLevel::ERROR) \ - << " Debug check failed: " #condition " " +#define RAY_DCHECK(condition) \ + (condition) \ + ? 
RAY_IGNORE_EXPR(0) \ + : ::ray::Voidify() & ::ray::RayLog(__FILE__, __LINE__, ray::RayLogLevel::ERROR) \ + << " Debug check failed: " #condition " " #else #define RAY_DCHECK(condition) RAY_CHECK(condition) diff --git a/src/ray/util/macros.h b/src/ray/util/macros.h index dbf85fe399e5..f105c4bd2b5f 100644 --- a/src/ray/util/macros.h +++ b/src/ray/util/macros.h @@ -8,7 +8,7 @@ void operator=(const TypeName &) = delete #endif -#define RAY_UNUSED(x) (void) x +#define RAY_UNUSED(x) (void)x // // GCC can be told that a certain branch is not likely to be taken (for