-
Notifications
You must be signed in to change notification settings - Fork 5.8k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
#22714 added `serve run` to the Serve CLI. This change allows the user to specify a local or remote `working_dir` in `serve run`.
- Loading branch information
1 parent
d1009c8
commit ab2741d
Showing
13 changed files
with
913 additions
and
361 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,315 @@ | ||
import dataclasses | ||
import importlib | ||
import logging | ||
import json | ||
import yaml | ||
from pathlib import Path | ||
import tempfile | ||
from typing import Any, Dict, List, Optional | ||
from pkg_resources import packaging | ||
|
||
try: | ||
import aiohttp | ||
import requests | ||
except ImportError: | ||
aiohttp = None | ||
requests = None | ||
|
||
from ray._private.runtime_env.packaging import ( | ||
create_package, | ||
get_uri_for_directory, | ||
parse_uri, | ||
) | ||
from ray.dashboard.modules.job.common import uri_to_http_components | ||
|
||
from ray.ray_constants import DEFAULT_DASHBOARD_PORT | ||
from ray.util.annotations import PublicAPI | ||
from ray.client_builder import _split_address | ||
from ray.autoscaler._private.cli_logger import cli_logger | ||
|
||
logger = logging.getLogger(__name__) | ||
logger.setLevel(logging.INFO) | ||
|
||
|
||
def parse_runtime_env_args(
    runtime_env: Optional[str] = None,
    runtime_env_json: Optional[str] = None,
    working_dir: Optional[str] = None,
) -> Dict[str, Any]:
    """Generate a runtime_env dictionary from the runtime-env CLI options.

    Only one of `runtime_env` or `runtime_env_json` may be defined.
    `working_dir` overwrites the `working_dir` from any other option.

    Args:
        runtime_env: Path to a YAML file holding a runtime_env definition.
        runtime_env_json: JSON-serialized runtime_env definition.
        working_dir: Value stored under the "working_dir" key of the result.

    Returns:
        The resulting runtime_env. It is an empty dictionary when no option
        is given, or when the YAML file / JSON string parses to `None`
        (an empty YAML document or the JSON literal `null`).

    Raises:
        ValueError: If both `runtime_env` and `runtime_env_json` are given.
    """
    # Reject the conflicting combination up front, before touching the file.
    if runtime_env is not None and runtime_env_json is not None:
        raise ValueError(
            "Only one of --runtime-env and --runtime-env-json can be provided."
        )

    final_runtime_env = None
    if runtime_env is not None:
        with open(runtime_env, "r") as f:
            final_runtime_env = yaml.safe_load(f)
    elif runtime_env_json is not None:
        final_runtime_env = json.loads(runtime_env_json)

    # An empty YAML document and JSON `null` both parse to None. Treat that
    # as "no settings" so the membership test and assignment below work.
    if final_runtime_env is None:
        final_runtime_env = {}

    if working_dir is not None:
        if "working_dir" in final_runtime_env:
            cli_logger.warning(
                "Overriding runtime_env working_dir with --working-dir option"
            )
        final_runtime_env["working_dir"] = working_dir

    return final_runtime_env
|
||
|
||
@dataclasses.dataclass
class ClusterInfo:
    """Connection details a SubmissionClient uses to talk to a cluster."""

    # Base address; SubmissionClient appends endpoint paths to it.
    address: str
    # Cookies passed along with every request, if any.
    cookies: Optional[Dict[str, Any]] = None
    # Stored by SubmissionClient as its default metadata, if any.
    metadata: Optional[Dict[str, Any]] = None
    # Headers passed along with every request, if any.
    headers: Optional[Dict[str, Any]] = None
|
||
|
||
def get_submission_client_cluster_info(
    address: str,
    # For backwards compatibility
    *,
    # only used in importlib case in parse_cluster_info, but needed
    # in function signature.
    create_cluster_if_needed: Optional[bool] = False,
    cookies: Optional[Dict[str, Any]] = None,
    metadata: Optional[Dict[str, Any]] = None,
    headers: Optional[Dict[str, Any]] = None,
    _use_tls: Optional[bool] = False,
) -> "ClusterInfo":
    """Get address, cookies, and metadata used for SubmissionClient.

    If no port is specified in `address`, the Ray dashboard default will be
    inserted.

    Args:
        address (str): Address without the module prefix that is passed
            to SubmissionClient, in the form `host` or `host:port`.
        create_cluster_if_needed (bool): Indicates whether the cluster
            of the address returned needs to be running. Ray doesn't
            start a cluster before interacting with jobs, but other
            implementations may do so. Not used by this implementation.
        cookies: Returned unchanged in the ClusterInfo.
        metadata: Returned unchanged in the ClusterInfo.
        headers: Returned unchanged in the ClusterInfo.
        _use_tls: When true the returned address uses the `https` scheme,
            otherwise `http`.

    Returns:
        ClusterInfo object consisting of address, cookies, and metadata
        for SubmissionClient to use.

    Raises:
        ValueError: If `address` contains more than one ":" or its port is
            not an integer.
    """
    scheme = "https" if _use_tls else "http"

    parts = address.split(":")
    if len(parts) > 2:
        raise ValueError(f"Invalid address: {address}.")

    host = parts[0]
    if len(parts) == 1:
        port = DEFAULT_DASHBOARD_PORT
    else:
        try:
            port = int(parts[1])
        except ValueError:
            # Report the whole address rather than the bare int() failure so
            # the user can see which input was rejected.
            raise ValueError(
                f"Invalid address: {address}. The port must be an integer."
            ) from None

    return ClusterInfo(
        address=f"{scheme}://{host}:{port}",
        cookies=cookies,
        metadata=metadata,
        headers=headers,
    )
|
||
|
||
def parse_cluster_info(
    address: str,
    create_cluster_if_needed: bool = False,
    cookies: Optional[Dict[str, Any]] = None,
    metadata: Optional[Dict[str, Any]] = None,
    headers: Optional[Dict[str, Any]] = None,
) -> ClusterInfo:
    """Resolve `address` into the ClusterInfo a SubmissionClient should use.

    The address is split into a module prefix and an inner address. The
    prefixes `http`, `https` and `ray` are handled by this module's
    `get_submission_client_cluster_info`. Any other prefix is imported as a
    Python module, whose own `get_submission_client_cluster_info` is called
    with the inner address.

    Args:
        address: Address including its module prefix.
        create_cluster_if_needed: Forwarded to the resolver function.
        cookies: Forwarded to the resolver function.
        metadata: Forwarded to the resolver function.
        headers: Forwarded to the resolver function.

    Returns:
        The ClusterInfo produced by the selected resolver function.

    Raises:
        RuntimeError: If the module named by the prefix cannot be imported.
        AssertionError: If that module does not provide
            `get_submission_client_cluster_info`.
    """
    module_string, inner_address = _split_address(address)

    # If user passes http(s):// or ray://, go through normal parsing.
    if module_string in {"http", "https", "ray"}:
        return get_submission_client_cluster_info(
            inner_address,
            create_cluster_if_needed=create_cluster_if_needed,
            cookies=cookies,
            metadata=metadata,
            headers=headers,
            _use_tls=module_string == "https",
        )

    # Otherwise try to dynamically import the function to get cluster info.
    try:
        module = importlib.import_module(module_string)
    except Exception:
        raise RuntimeError(
            f"Module: {module_string} does not exist.\n"
            f"This module was parsed from Address: {address}"
        ) from None

    resolver = getattr(module, "get_submission_client_cluster_info", None)
    if resolver is None:
        # Raised explicitly instead of via `assert` so the check still runs
        # under `python -O`; the exception type is kept for existing callers.
        raise AssertionError(
            f"Module: {module_string} does "
            "not have `get_submission_client_cluster_info`."
        )

    return resolver(
        inner_address,
        create_cluster_if_needed=create_cluster_if_needed,
        cookies=cookies,
        metadata=metadata,
        headers=headers,
    )
|
||
|
||
class SubmissionClient:
    """HTTP client bound to the cluster resolved from an address.

    The address is resolved once, at construction time, through
    `parse_cluster_info`; every request afterwards is sent to the resolved
    address with the resolved cookies and headers.
    """

    def __init__(
        self,
        address: str,
        create_cluster_if_needed=False,
        cookies: Optional[Dict[str, Any]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        headers: Optional[Dict[str, Any]] = None,
    ):
        """Resolve `address` and remember the connection settings.

        Args:
            address: Address including its module prefix; see
                `parse_cluster_info`.
            create_cluster_if_needed: Forwarded to `parse_cluster_info`.
            cookies: Forwarded to `parse_cluster_info`.
            metadata: Forwarded to `parse_cluster_info`.
            headers: Forwarded to `parse_cluster_info`.

        Raises:
            RuntimeError: If the optional HTTP dependencies failed to import.
        """
        # The module-level import sets `requests` to None when either
        # optional dependency is missing. Fail here with a clear message
        # instead of an AttributeError on None at the first request.
        if requests is None:
            raise RuntimeError(
                "SubmissionClient requires the `aiohttp` and `requests` "
                "packages, but at least one of them could not be imported. "
                'Install them, e.g. with `pip install "ray[default]"`.'
            )

        cluster_info = parse_cluster_info(
            address, create_cluster_if_needed, cookies, metadata, headers
        )
        self._address = cluster_info.address
        self._cookies = cluster_info.cookies
        self._default_metadata = cluster_info.metadata or {}
        # Headers used for all requests sent to job server, optional and only
        # needed for cases like authentication to remote cluster.
        self._headers = cluster_info.headers

    def _check_connection_and_version(
        self, min_version: str = "1.9", version_error_message: Optional[str] = None
    ):
        """Verify the cluster is reachable and runs at least `min_version`.

        Args:
            min_version: Lowest acceptable value of the "ray_version" field
                returned by the `/api/version` endpoint.
            version_error_message: Message of the RuntimeError raised when
                the version check fails; a default is built from
                `min_version` when omitted.

        Raises:
            RuntimeError: If the endpoint responds with 404 or reports a
                version lower than `min_version`.
            ConnectionError: If the request cannot connect to the address.
        """
        if version_error_message is None:
            version_error_message = (
                f"Please ensure the cluster is running Ray {min_version} or higher."
            )

        try:
            r = self._do_request("GET", "/api/version")
            # A missing endpoint is reported as a version problem.
            if r.status_code == 404:
                raise RuntimeError(version_error_message)
            r.raise_for_status()

            running_ray_version = r.json()["ray_version"]
            if packaging.version.parse(running_ray_version) < packaging.version.parse(
                min_version
            ):
                raise RuntimeError(version_error_message)
            # TODO(edoakes): check the version if/when we break compatibility.
        except requests.exceptions.ConnectionError:
            raise ConnectionError(
                f"Failed to connect to Ray at address: {self._address}."
            )

    def _raise_error(self, r: "requests.Response"):
        """Raise a RuntimeError describing the failed response `r`."""
        raise RuntimeError(
            f"Request failed with status code {r.status_code}: {r.text}."
        )

    def _do_request(
        self,
        method: str,
        endpoint: str,
        *,
        data: Optional[bytes] = None,
        json_data: Optional[dict] = None,
    ) -> "requests.Response":
        """Send one HTTP request to `endpoint` on the resolved address.

        Args:
            method: HTTP method name, e.g. "GET" or "PUT".
            endpoint: Path appended verbatim to the resolved address.
            data: Raw request body, if any.
            json_data: Body to send JSON-encoded, if any.

        Returns:
            The response; its status code is not checked here.
        """
        url = self._address + endpoint
        logger.debug(f"Sending request to {url} with json data: {json_data or {}}.")
        return requests.request(
            method,
            url,
            cookies=self._cookies,
            data=data,
            json=json_data,
            headers=self._headers,
        )

    def _package_exists(
        self,
        package_uri: str,
    ) -> bool:
        """Return whether the server already has the package `package_uri`.

        Raises:
            RuntimeError: If the server answers with a status other than
                200 or 404.
        """
        protocol, package_name = uri_to_http_components(package_uri)
        r = self._do_request("GET", f"/api/packages/{protocol}/{package_name}")

        if r.status_code == 200:
            logger.debug(f"Package {package_uri} already exists.")
            return True
        elif r.status_code == 404:
            logger.debug(f"Package {package_uri} does not exist.")
            return False
        else:
            self._raise_error(r)

    def _upload_package(
        self,
        package_uri: str,
        package_path: str,
        include_parent_dir: Optional[bool] = False,
        excludes: Optional[List[str]] = None,
    ) -> None:
        """Package `package_path` and upload it under `package_uri`.

        The package file is created in a temporary directory, sent with a
        PUT request, and removed afterwards.

        Args:
            package_uri: URI the package is stored under on the server.
            package_path: Local path handed to `create_package`.
            include_parent_dir: Forwarded to `create_package`.
            excludes: Forwarded to `create_package`.

        Raises:
            RuntimeError: If the server does not answer with status 200.
        """
        logger.info(f"Uploading package {package_uri}.")
        with tempfile.TemporaryDirectory() as tmp_dir:
            protocol, package_name = uri_to_http_components(package_uri)
            package_file = Path(tmp_dir) / package_name
            create_package(
                package_path,
                package_file,
                include_parent_dir=include_parent_dir,
                excludes=excludes,
            )
            try:
                r = self._do_request(
                    "PUT",
                    f"/api/packages/{protocol}/{package_name}",
                    data=package_file.read_bytes(),
                )
                if r.status_code != 200:
                    self._raise_error(r)
            finally:
                # Remove the package file as soon as the request is done,
                # whether or not it succeeded.
                package_file.unlink()

    def _upload_package_if_needed(
        self, package_path: str, excludes: Optional[List[str]] = None
    ) -> str:
        """Upload `package_path` unless the server already has it.

        Args:
            package_path: Local directory to upload.
            excludes: Forwarded to the URI computation and to the upload.

        Returns:
            The package URI computed for the directory.
        """
        package_uri = get_uri_for_directory(package_path, excludes=excludes)
        if not self._package_exists(package_uri):
            self._upload_package(package_uri, package_path, excludes=excludes)
        else:
            logger.info(f"Package {package_uri} already exists, skipping upload.")

        return package_uri

    def _upload_working_dir_if_needed(self, runtime_env: Dict[str, Any]):
        """Replace a local `working_dir` in `runtime_env` with a package URI.

        Nothing happens when `runtime_env` has no "working_dir" entry or the
        entry is already accepted by `parse_uri`. Otherwise the value is
        uploaded as a package and `runtime_env` is modified in place so that
        "working_dir" holds the resulting URI.
        """
        if "working_dir" in runtime_env:
            working_dir = runtime_env["working_dir"]
            # parse_uri raising ValueError is what marks the value as a
            # non-URI that needs uploading.
            try:
                parse_uri(working_dir)
                is_uri = True
                logger.debug("working_dir is already a valid URI.")
            except ValueError:
                is_uri = False

            if not is_uri:
                logger.debug("working_dir is not a URI, attempting to upload.")
                package_uri = self._upload_package_if_needed(
                    working_dir, excludes=runtime_env.get("excludes", None)
                )
                runtime_env["working_dir"] = package_uri

    @PublicAPI(stability="beta")
    def get_version(self) -> str:
        """Return the "version" field reported by `/api/version`.

        Raises:
            RuntimeError: If the server does not answer with status 200.
        """
        r = self._do_request("GET", "/api/version")
        if r.status_code == 200:
            return r.json().get("version")
        else:
            self._raise_error(r)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.