Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[runtime env] URI reference refactor #22828

Merged
merged 27 commits into from
Mar 21, 2022
Merged
Show file tree
Hide file tree
Changes from 22 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions dashboard/consts.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,8 @@
# Default value for datacenter (the default value in protobuf)
DEFAULT_LANGUAGE = "PYTHON"
DEFAULT_JOB_ID = "ffff"
# Cache TTL for a bad runtime env. After this time, the cached failure is deleted and
# creation of the runtime env is retried if needed.
BAD_RUNTIME_ENV_CACHE_TTL_SECONDS = env_integer(
"BAD_RUNTIME_ENV_CACHE_TTL_SECONDS", 60 * 10
)
360 changes: 268 additions & 92 deletions dashboard/modules/runtime_env/runtime_env_agent.py

Large diffs are not rendered by default.

9 changes: 0 additions & 9 deletions python/ray/_private/runtime_env/plugin.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from abc import ABC, abstractstaticmethod

from typing import Tuple
from ray.util.annotations import DeveloperAPI
from ray._private.runtime_env.context import RuntimeEnvContext

Expand All @@ -11,14 +10,6 @@ def encode_plugin_uri(plugin: str, uri: str) -> str:
return plugin + "|" + uri


def decode_plugin_uri(plugin_uri: str) -> Tuple[str, str]:
    """Split an encoded ``plugin|uri`` string into ``(plugin, uri)``.

    Only the first ``|`` is treated as the separator, so a URI part that
    itself contains ``|`` is returned intact and the result is always a
    2-tuple.

    Args:
        plugin_uri: String of the form ``plugin|uri``.

    Returns:
        A ``(plugin, uri)`` tuple.

    Raises:
        ValueError: If ``plugin_uri`` contains no ``|`` separator.
    """
    plugin, separator, uri = plugin_uri.partition("|")
    if not separator:
        raise ValueError(
            f"Plugin URI must be of the form 'plugin|uri', not {plugin_uri}"
        )
    return plugin, uri


@DeveloperAPI
class RuntimeEnvPlugin(ABC):
@abstractstaticmethod
Expand Down
1 change: 1 addition & 0 deletions python/ray/tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,7 @@ py_test_module_list(
py_test_module_list(
files = [
"test_runtime_env.py",
"test_runtime_env_2.py",
"test_runtime_env_working_dir.py",
"test_runtime_env_working_dir_2.py",
"test_runtime_env_working_dir_3.py",
Expand Down
8 changes: 8 additions & 0 deletions python/ray/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -654,3 +654,11 @@ def listen_port(request):
yield port
finally:
sock.close()


@pytest.fixture
def set_bad_runtime_env_cache_ttl_seconds(request):
    """Set ``BAD_RUNTIME_ENV_CACHE_TTL_SECONDS`` in the environment for a test.

    The value is taken from ``request.param`` (indirect parametrization) and
    defaults to ``"0"``. The fixture yields that value. On teardown the
    variable is restored to whatever it was before the fixture ran, or
    removed if it was not set.
    """
    env_key = "BAD_RUNTIME_ENV_CACHE_TTL_SECONDS"
    ttl = getattr(request, "param", "0")
    previous = os.environ.get(env_key)
    os.environ[env_key] = ttl
    try:
        yield ttl
    finally:
        if previous is None:
            # pop() tolerates the test body having already removed the key.
            os.environ.pop(env_key, None)
        else:
            os.environ[env_key] = previous
43 changes: 0 additions & 43 deletions python/ray/tests/test_runtime_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,49 +145,6 @@ def test_container_option_serialize(runtime_env_class):
assert job_config_serialized.count(b"--name=test") == 1


@pytest.mark.skipif(
sys.platform == "win32", reason="conda in runtime_env unsupported on Windows."
)
@pytest.mark.parametrize("runtime_env_class", [dict, RuntimeEnv])
def test_invalid_conda_env(shutdown_only, runtime_env_class):
ray.init()

@ray.remote
def f():
pass

@ray.remote
class A:
def f(self):
pass

start = time.time()
bad_env = runtime_env_class(conda={"dependencies": ["this_doesnt_exist"]})
with pytest.raises(
RuntimeEnvSetupError,
# The actual error message should be included in the exception.
match="ResolvePackageNotFound",
):
ray.get(f.options(runtime_env=bad_env).remote())
first_time = time.time() - start

# Check that another valid task can run.
ray.get(f.remote())

a = A.options(runtime_env=bad_env).remote()
with pytest.raises(
ray.exceptions.RuntimeEnvSetupError, match="ResolvePackageNotFound"
):
ray.get(a.f.remote())

# The second time this runs it should be faster as the error is cached.
start = time.time()
with pytest.raises(RuntimeEnvSetupError, match="ResolvePackageNotFound"):
ray.get(f.options(runtime_env=bad_env).remote())

assert (time.time() - start) < (first_time / 2.0)


@pytest.mark.skipif(
sys.platform == "win32", reason="runtime_env unsupported on Windows."
)
Expand Down
78 changes: 78 additions & 0 deletions python/ray/tests/test_runtime_env_2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import pytest
import sys
import time

import ray
from ray.exceptions import RuntimeEnvSetupError
from ray.runtime_env import RuntimeEnv


bad_runtime_env_cache_ttl_seconds = 10


@pytest.mark.skipif(
    sys.platform == "win32", reason="conda in runtime_env unsupported on Windows."
)
@pytest.mark.parametrize("runtime_env_class", [dict, RuntimeEnv])
@pytest.mark.parametrize(
    "set_bad_runtime_env_cache_ttl_seconds",
    [str(bad_runtime_env_cache_ttl_seconds)],
    indirect=True,
)
def test_invalid_conda_env(
    shutdown_only, runtime_env_class, set_bad_runtime_env_cache_ttl_seconds
):
    """Check failure caching (and cache expiry) for an invalid conda env.

    A conda env with a nonexistent dependency must fail with
    RuntimeEnvSetupError carrying the underlying error text. A repeat
    attempt is expected to be faster because the failure is cached, and an
    attempt made after sleeping for the configured TTL is expected to be
    slow again.
    """
    expected_error = "ResolvePackageNotFound"

    ray.init()

    @ray.remote
    def f():
        pass

    @ray.remote
    class A:
        def f(self):
            pass

    def expect_setup_failure(submit):
        # The actual error message should be included in the exception.
        with pytest.raises(RuntimeEnvSetupError, match=expected_error):
            ray.get(submit())

    # First attempt: nothing is cached yet, so this is the baseline duration.
    began = time.time()
    bad_env = runtime_env_class(conda={"dependencies": ["this_doesnt_exist"]})
    expect_setup_failure(lambda: f.options(runtime_env=bad_env).remote())
    first_time = time.time() - began
    half_first_time = first_time / 2.0

    # Check that another valid task can run.
    ray.get(f.remote())

    # An actor using the same bad env must fail in the same way.
    a = A.options(runtime_env=bad_env).remote()
    expect_setup_failure(lambda: a.f.remote())

    # The second time this runs it should be faster as the error is cached.
    began = time.time()
    expect_setup_failure(lambda: f.options(runtime_env=bad_env).remote())
    assert (time.time() - began) < half_first_time

    # Sleep until the bad runtime env cache entry has been removed.
    time.sleep(bad_runtime_env_cache_ttl_seconds)

    # The third time this runs it should be slower as the error isn't cached.
    began = time.time()
    expect_setup_failure(lambda: f.options(runtime_env=bad_env).remote())
    assert (time.time() - began) > half_first_time


if __name__ == "__main__":
    # `sys` is already imported at module top, so no local re-import is needed.
    # Run this file's tests verbosely (-v) without output capture (-s) and
    # propagate pytest's exit code to the shell.
    sys.exit(pytest.main(["-sv", __file__]))
8 changes: 1 addition & 7 deletions python/ray/tests/test_runtime_env_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
parse_and_validate_env_vars,
parse_and_validate_py_modules,
)
from ray._private.runtime_env.plugin import decode_plugin_uri, encode_plugin_uri
from ray._private.runtime_env.plugin import encode_plugin_uri
from ray.runtime_env import RuntimeEnv

CONDA_DICT = {"dependencies": ["pip", {"pip": ["pip-install-test==0.5"]}]}
Expand Down Expand Up @@ -55,12 +55,6 @@ def test_encode_plugin_uri():
assert encode_plugin_uri("plugin", "uri") == "plugin|uri"


def test_decode_plugin_uri():
    """decode_plugin_uri splits 'plugin|uri' and rejects input without '|'."""
    plugin_and_uri = decode_plugin_uri("plugin|uri")
    assert plugin_and_uri == ("plugin", "uri")

    # Input lacking the separator must be rejected.
    with pytest.raises(ValueError):
        decode_plugin_uri("no_vertical_bar_separator")


class TestValidateWorkingDir:
def test_validate_bad_uri(self):
with pytest.raises(ValueError, match="a valid URI"):
Expand Down
21 changes: 14 additions & 7 deletions python/ray/util/client/server/proxier.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,25 +214,32 @@ def create_specific_server(self, client_id: str) -> SpecificServer:
self.servers[client_id] = server
return server

def _create_runtime_env(
def _increase_runtime_env_reference(
self, serialized_runtime_env: str, specific_server: SpecificServer
):
"""Creates the runtime_env by sending an RPC to the agent.
"""Increase the runtime_env reference by sending an RPC to the agent.

Includes retry logic to handle the case when the agent is
temporarily unreachable (e.g., hasn't been started up yet).
"""
create_env_request = runtime_env_agent_pb2.CreateRuntimeEnvRequest(
logger.info(
f"Increasing runtime env reference for "
f"ray_client_server_{specific_server.port}."
f"Serialized runtime env is {serialized_runtime_env}."
)

create_env_request = runtime_env_agent_pb2.GetOrCreateRuntimeEnvRequest(
serialized_runtime_env=serialized_runtime_env,
job_id=f"ray_client_server_{specific_server.port}".encode("utf-8"),
source_process="client_server",
)

retries = 0
max_retries = 5
wait_time_s = 0.5
while retries <= max_retries:
try:
r = self._runtime_env_stub.CreateRuntimeEnv(create_env_request)
r = self._runtime_env_stub.GetOrCreateRuntimeEnv(create_env_request)
if r.status == agent_manager_pb2.AgentRpcStatus.AGENT_RPC_STATUS_OK:
return r.serialized_runtime_env_context
elif (
Expand All @@ -256,7 +263,7 @@ def _create_runtime_env(
raise e

logger.warning(
f"CreateRuntimeEnv request failed: {e}. "
f"GetOrCreateRuntimeEnv request failed: {e}. "
f"Retrying after {wait_time_s}s. "
f"{max_retries-retries} retries remaining."
)
Expand All @@ -267,7 +274,7 @@ def _create_runtime_env(
wait_time_s *= 2

raise TimeoutError(
f"CreateRuntimeEnv request failed after {max_retries} attempts."
f"GetOrCreateRuntimeEnv request failed after {max_retries} attempts."
)

def start_specific_server(self, client_id: str, job_config: JobConfig) -> bool:
Expand All @@ -288,7 +295,7 @@ def start_specific_server(self, client_id: str, job_config: JobConfig) -> bool:
# to the agent?
serialized_runtime_env_context = RuntimeEnvContext().serialize()
else:
serialized_runtime_env_context = self._create_runtime_env(
serialized_runtime_env_context = self._increase_runtime_env_reference(
serialized_runtime_env=serialized_runtime_env,
specific_server=specific_server,
)
Expand Down
2 changes: 1 addition & 1 deletion src/mock/ray/raylet/agent_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class MockAgentManager : public AgentManager {
rpc::SendReplyCallback send_reply_callback),
(override));
MOCK_METHOD(void,
CreateRuntimeEnv,
GetOrCreateRuntimeEnv,
(const JobID &job_id,
const std::string &serialized_runtime_env,
CreateRuntimeEnvCallback callback),
Expand Down
18 changes: 11 additions & 7 deletions src/ray/protobuf/runtime_env_agent.proto
Original file line number Diff line number Diff line change
Expand Up @@ -19,31 +19,35 @@ package ray.rpc;

import "src/ray/protobuf/agent_manager.proto";

message CreateRuntimeEnvRequest {
message GetOrCreateRuntimeEnvRequest {
string serialized_runtime_env = 1;
bytes job_id = 2;
// Serialized allocated resource instances. Key is the resource type, value is the
// allocated instances. For example, {"CPU":20000,"memory":40000,"GPU":[10000, 10000]}
// means 2 CPU cores, 2 Gi memory, GPU 0 and GPU 1.
string serialized_allocated_resource_instances = 3;
string source_process = 4;
}

message CreateRuntimeEnvReply {
message GetOrCreateRuntimeEnvReply {
AgentRpcStatus status = 1;
string error_message = 2;
string serialized_runtime_env_context = 3;
}

message DeleteURIsRequest {
repeated string uris = 1;
message DeleteRuntimeEnvIfPossibleRequest {
string serialized_runtime_env = 1;
string source_process = 2;
}

message DeleteURIsReply {
message DeleteRuntimeEnvIfPossibleReply {
AgentRpcStatus status = 1;
string error_message = 2;
}

service RuntimeEnvService {
rpc CreateRuntimeEnv(CreateRuntimeEnvRequest) returns (CreateRuntimeEnvReply);
rpc DeleteURIs(DeleteURIsRequest) returns (DeleteURIsReply);
rpc GetOrCreateRuntimeEnv(GetOrCreateRuntimeEnvRequest)
returns (GetOrCreateRuntimeEnvReply);
rpc DeleteRuntimeEnvIfPossible(DeleteRuntimeEnvIfPossibleRequest)
returns (DeleteRuntimeEnvIfPossibleReply);
}
Loading