diff --git a/python/ray/tests/test_client_reconnect.py b/python/ray/tests/test_client_reconnect.py
index 155187dbc5bd..2d8ec54b7c8a 100644
--- a/python/ray/tests/test_client_reconnect.py
+++ b/python/ray/tests/test_client_reconnect.py
@@ -1,11 +1,9 @@
 from concurrent import futures
-import asyncio
 import contextlib
 import os
 import threading
 import sys
 import grpc
-import numpy as np
 import time
 import random
 
@@ -127,8 +125,7 @@ def _call_inner_function(
             context.set_code(e.code())
             context.set_details(e.details())
             raise
-        if self.on_response and method != "GetObject":
-            # GetObject streams response, handle on_response separately
+        if self.on_response:
             self.on_response(response)
         return response
 
@@ -162,10 +159,7 @@ def Terminate(self, req, context=None):
         return self._call_inner_function(req, context, "Terminate")
 
     def GetObject(self, request, context=None):
-        for response in self._call_inner_function(request, context, "GetObject"):
-            if self.on_response:
-                self.on_response(response)
-            yield response
+        return self._call_inner_function(request, context, "GetObject")
 
     def PutObject(
         self, request: ray_client_pb2.PutRequest, context=None
@@ -276,8 +270,8 @@ def start_middleman_server(
         real_addr="localhost:50051",
        on_log_response=on_log_response,
         on_data_response=on_data_response,
-        on_task_request=on_task_request,
-        on_task_response=on_task_response,
+        on_task_request=on_task_request,
+        on_task_response=on_task_response,
     )
     middleman.start()
     ray.init("ray://localhost:10011")
@@ -325,73 +319,6 @@ def disconnect(middleman):
     disconnect_thread.join()
 
 
-def test_disconnects_during_large_get():
-    """
-    Disconnect repeatedly during a large (multi-chunk) get.
-    """
-    i = 0
-    started = False
-
-    def fail_every_three(_):
-        # Inject an error every third time this method is called
-        nonlocal i, started
-        if not started:
-            return
-        i += 1
-        if i % 3 == 0:
-            raise RuntimeError
-
-    @ray.remote
-    def large_result():
-        # 1024x1024x128 float64 matrix (1024 MiB). With 64MiB chunk size,
-        # it will take at least 16 chunks to transfer this object. Since
-        # the failure is injected every 3 chunks, this transfer can only
-        # work if the chunked get request retries at the last received chunk
-        # (instead of starting from the beginning each retry)
-        return np.random.random((1024, 1024, 128))
-
-    with start_middleman_server(on_task_response=fail_every_three):
-        started = True
-        result = ray.get(large_result.remote())
-        assert result.shape == (1024, 1024, 128)
-
-
-def test_disconnects_during_large_async_get():
-    """
-    Disconnect repeatedly during a large (multi-chunk) async get.
-    """
-    i = 0
-    started = False
-
-    def fail_every_three(_):
-        # Inject an error every third time this method is called
-        nonlocal i, started
-        if not started:
-            return
-        i += 1
-        if i % 3 == 0:
-            raise RuntimeError
-
-    @ray.remote
-    def large_result():
-        # 1024x1024x128 float64 matrix (1024 MiB). With 64MiB chunk size,
-        # it will take at least 16 chunks to transfer this object. Since
-        # the failure is injected every 3 chunks, this transfer can only
-        # work if the chunked get request retries at the last received chunk
-        # (instead of starting from the beginning each retry)
-        return np.random.random((1024, 1024, 128))
-
-    with start_middleman_server(on_data_response=fail_every_three):
-        started = True
-
-        async def get_large_result():
-            return await large_result.remote()
-
-        loop = asyncio.get_event_loop()
-        result = loop.run_until_complete(get_large_result())
-        assert result.shape == (1024, 1024, 128)
-
-
 def test_valid_actor_state():
     """
     Repeatedly inject errors in the middle of mutating actor calls. Check
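Note on the test hunks above: the two deleted tests were the only ones that depended on chunk-level retry; the middleman's fault-injection hooks themselves remain in place for the surviving tests. A minimal sketch of that pattern, assuming `start_middleman_server` from this test file (a context manager that calls `ray.init()` internally, per the hunk above) and an illustrative failure cadence:

    import ray

    calls = 0

    def fail_every_three(_response):
        # Injects a synthetic error on every third intercepted response.
        global calls
        calls += 1
        if calls % 3 == 0:
            raise RuntimeError("injected middleman failure")

    @ray.remote
    def small_result():
        return 42

    with start_middleman_server(on_data_response=fail_every_three):
        # The client's reconnect logic should retry past the injected errors.
        assert ray.get(small_result.remote()) == 42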
diff --git a/python/ray/util/client/__init__.py b/python/ray/util/client/__init__.py
index bc11c2042321..12cde3b78b57 100644
--- a/python/ray/util/client/__init__.py
+++ b/python/ray/util/client/__init__.py
@@ -16,7 +16,7 @@
 
 # This version string is incremented to indicate breaking changes in the
 # protocol that require upgrading the client version.
-CURRENT_PROTOCOL_VERSION = "2022-02-14"
+CURRENT_PROTOCOL_VERSION = "2021-12-07"
 
 
 class _ClientContext:
diff --git a/python/ray/util/client/common.py b/python/ray/util/client/common.py
index c8b96e7c1a65..ddb4a57384f5 100644
--- a/python/ray/util/client/common.py
+++ b/python/ray/util/client/common.py
@@ -84,12 +84,6 @@
 CLIENT_SERVER_MAX_THREADS = float(os.getenv("RAY_CLIENT_SERVER_MAX_THREADS", 100))
 
-# Large objects are chunked into 64 MiB messages
-OBJECT_TRANSFER_CHUNK_SIZE = 64 * 2 ** 20
-
-# Warn the user if the object being transferred is larger than 2 GiB
-OBJECT_TRANSFER_WARNING_SIZE = 2 * 2 ** 30
-
 
 class ClientObjectRef(raylet.ObjectRef):
     def __init__(self, id: Union[bytes, Future]):
@@ -172,8 +166,6 @@ def deserialize_obj(
 
         if isinstance(resp, Exception):
             data = resp
-        elif isinstance(resp, bytearray):
-            data = loads_from_server(resp)
         else:
             obj = resp.get
             data = None
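For reference, the two constants deleted from common.py drove a fixed-size slicing scheme on the server side; the arithmetic, reconstructed as a standalone sketch from the server.py loop removed further down in this diff (these names no longer exist once this diff lands):

    import math

    OBJECT_TRANSFER_CHUNK_SIZE = 64 * 2 ** 20  # 64 MiB per gRPC message

    def iter_chunks(serialized: bytes, start_chunk_id: int = 0):
        # Yields (chunk_id, payload) slices so an interrupted transfer can
        # resume from start_chunk_id instead of byte zero.
        total_size = len(serialized)
        total_chunks = math.ceil(total_size / OBJECT_TRANSFER_CHUNK_SIZE)
        for chunk_id in range(start_chunk_id, total_chunks):
            start = chunk_id * OBJECT_TRANSFER_CHUNK_SIZE
            end = min(total_size, (chunk_id + 1) * OBJECT_TRANSFER_CHUNK_SIZE)
            yield chunk_id, serialized[start:end]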
diff --git a/python/ray/util/client/dataclient.py b/python/ray/util/client/dataclient.py
index 78bafbbe15ce..44915e908fb7 100644
--- a/python/ray/util/client/dataclient.py
+++ b/python/ray/util/client/dataclient.py
@@ -4,7 +4,6 @@
 import logging
 import queue
 import threading
-import warnings
 import grpc
 
 from collections import OrderedDict
@@ -12,8 +11,7 @@
 import ray.core.generated.ray_client_pb2 as ray_client_pb2
 import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
-from ray.util.client.common import INT32_MAX, OBJECT_TRANSFER_WARNING_SIZE
-from ray.util.debug import log_once
+from ray.util.client.common import INT32_MAX
 
 if TYPE_CHECKING:
     from ray.util.client.worker import Worker
@@ -26,83 +24,6 @@
 ACKNOWLEDGE_BATCH_SIZE = 32
 
 
-class ChunkCollector:
-    """
-    This object collects chunks from async get requests via __call__, and
-    calls the underlying callback when the object is fully received, or if an
-    exception while retrieving the object occurs.
-
-    This is not used in synchronous gets (synchronous gets interact with the
-    raylet servicer directly, not through the datapath).
-
-    __call__ returns true once the underlying call back has been called.
-    """
-
-    def __init__(self, callback: ResponseCallable, request: ray_client_pb2.DataRequest):
-        # Bytearray containing data received so far
-        self.data = bytearray()
-        # The callback that will be called once all data is received
-        self.callback = callback
-        # The id of the last chunk we've received, or -1 if haven't seen any yet
-        self.last_seen_chunk = -1
-        # The GetRequest that initiated the transfer. start_chunk_id will be
-        # updated as chunks are received to avoid re-requesting chunks that
-        # we've already received.
-        self.request = request
-
-    def __call__(self, response: Union[ray_client_pb2.DataResponse, Exception]) -> bool:
-        if isinstance(response, Exception):
-            self.callback(response)
-            return True
-        get_resp = response.get
-        if not get_resp.valid:
-            self.callback(response)
-            return True
-        if get_resp.total_size > OBJECT_TRANSFER_WARNING_SIZE and log_once(
-            "client_object_transfer_size_warning"
-        ):
-            size_gb = get_resp.total_size / 2 ** 30
-            warnings.warn(
-                "Ray Client is attempting to retrieve a "
-                f"{size_gb:.2f} GiB object over the network, which may "
-                "be slow. Consider serializing the object to a file and "
-                "using rsync or S3 instead.",
-                UserWarning,
-            )
-        chunk_data = get_resp.data
-        chunk_id = get_resp.chunk_id
-        if chunk_id == self.last_seen_chunk + 1:
-            self.data.extend(chunk_data)
-            self.last_seen_chunk = chunk_id
-            # If we disconnect partway through, restart the get request
-            # at the first chunk we haven't seen
-            self.request.get.start_chunk_id = self.last_seen_chunk + 1
-        elif chunk_id > self.last_seen_chunk + 1:
-            # A chunk was skipped. This shouldn't happen in practice since
-            # grpc guarantees that chunks will arrive in order.
-            msg = (
-                f"Received chunk {chunk_id} when we expected "
-                f"{self.last_seen_chunk + 1} for request {response.req_id}"
-            )
-            logger.warning(msg)
-            self.callback(RuntimeError(msg))
-            return True
-        else:
-            # We received a chunk that've already seen before. Ignore, since
-            # it should already be appended to self.data.
-            logger.debug(
-                f"Received a repeated chunk {chunk_id} "
-                f"from request {response.req_id}."
-            )
-
-        if get_resp.chunk_id == get_resp.total_chunks - 1:
-            self.callback(self.data)
-            return True
-        else:
-            # Not done yet
-            return False
-
-
 class DataClient:
     def __init__(self, client_worker: "Worker", client_id: str, metadata: list):
         """Initializes a thread-safe datapath over a Ray Client gRPC channel.
@@ -198,25 +119,20 @@ def _process_response(self, response: Any) -> None:
             logger.debug(f"Got unawaited response {response}")
             return
         if response.req_id in self.asyncio_waiting_data:
-            can_remove = True
             try:
-                callback = self.asyncio_waiting_data[response.req_id]
-                if isinstance(callback, ChunkCollector):
-                    can_remove = callback(response)
-                elif callback:
+                # NOTE: calling self.asyncio_waiting_data.pop() results
+                # in the destructor of ClientObjectRef running, which
+                # calls ReleaseObject(). So self.asyncio_waiting_data
+                # is accessed without holding self.lock. Holding the
+                # lock shouldn't be necessary either.
+                callback = self.asyncio_waiting_data.pop(response.req_id)
+                if callback:
                     callback(response)
-                if can_remove:
-                    # NOTE: calling del self.asyncio_waiting_data results
-                    # in the destructor of ClientObjectRef running, which
-                    # calls ReleaseObject(). So self.asyncio_waiting_data
-                    # is accessed without holding self.lock. Holding the
-                    # lock shouldn't be necessary either.
-                    del self.asyncio_waiting_data[response.req_id]
             except Exception:
                 logger.exception("Callback error:")
             with self.lock:
                 # Update outstanding requests
-                if response.req_id in self.outstanding_requests and can_remove:
+                if response.req_id in self.outstanding_requests:
                     del self.outstanding_requests[response.req_id]
                 # Acknowledge response
                 self._acknowledge(response.req_id)
@@ -454,8 +370,7 @@ def RegisterGetCallback(
         datareq = ray_client_pb2.DataRequest(
             get=request,
         )
-        collector = ChunkCollector(callback=callback, request=datareq)
-        self._async_send(datareq, collector)
+        self._async_send(datareq, callback)
 
     # TODO: convert PutObject to async
     def PutObject(
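The dataclient change above restores a strictly unary contract: each req_id receives exactly one DataResponse, so the callback can be popped and fired once, with no ChunkCollector deciding whether the entry may be removed yet. A stripped-down sketch of the restored flow (hypothetical names, not the real class):

    waiting = {}

    def register(req_id, callback):
        waiting[req_id] = callback

    def process_response(response):
        # One response per request: pop first, then invoke exactly once.
        callback = waiting.pop(response.req_id, None)
        if callback is not None:
            callback(response)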
diff --git a/python/ray/util/client/server/proxier.py b/python/ray/util/client/server/proxier.py
index 60d969726114..ee7995d05e44 100644
--- a/python/ray/util/client/server/proxier.py
+++ b/python/ray/util/client/server/proxier.py
@@ -530,7 +530,7 @@ def Terminate(self, req, context=None):
         return self._call_inner_function(req, context, "Terminate")
 
     def GetObject(self, request, context=None):
-        yield from self._call_inner_function(request, context, "GetObject")
+        return self._call_inner_function(request, context, "GetObject")
 
     def PutObject(
         self, request: ray_client_pb2.PutRequest, context=None
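The proxier change above follows from the RPC going back to unary (see the ray_client.proto hunk at the end of this diff): in grpcio's Python servicers, a method containing `yield` is a generator and is only valid for streaming RPCs, while a unary handler must return a single message. A toy illustration with a hypothetical `ThingResponse` message type:

    # Unary RPC: the handler returns one message.
    def GetThing(self, request, context):
        return ThingResponse(value=42)

    # Server-streaming RPC: the handler is a generator.
    def GetThings(self, request, context):
        for v in range(3):
            yield ThingResponse(value=v)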
diff --git a/python/ray/util/client/server/server.py b/python/ray/util/client/server/server.py
index 67f28e12d8ad..3653fa47fe94 100644
--- a/python/ray/util/client/server/server.py
+++ b/python/ray/util/client/server/server.py
@@ -5,7 +5,6 @@
 import base64
 from collections import defaultdict
 import functools
-import math
 import queue
 import pickle
 
@@ -28,7 +27,6 @@
     ClientServerHandle,
     GRPC_OPTIONS,
     CLIENT_SERVER_MAX_THREADS,
-    OBJECT_TRANSFER_CHUNK_SIZE,
     ResponseCache,
 )
 from ray import ray_constants
@@ -380,38 +378,20 @@ def _async_get_object(
         with disable_client_hook():
 
             def send_get_response(result: Any) -> None:
-                """Pushes GetResponses to the main DataPath loop to send
+                """Pushes a GetResponse to the main DataPath loop to send
                 to the client. This is called when the object is ready
                 on the server side."""
                 try:
                     serialized = dumps_from_server(result, client_id, self)
-                    total_size = len(serialized)
-                    assert total_size > 0, "Serialized object cannot be zero bytes"
-                    total_chunks = math.ceil(
-                        total_size / OBJECT_TRANSFER_CHUNK_SIZE
+                    get_resp = ray_client_pb2.GetResponse(
+                        valid=True, data=serialized
                     )
-                    for chunk_id in range(request.start_chunk_id, total_chunks):
-                        start = chunk_id * OBJECT_TRANSFER_CHUNK_SIZE
-                        end = min(
-                            total_size, (chunk_id + 1) * OBJECT_TRANSFER_CHUNK_SIZE
-                        )
-                        get_resp = ray_client_pb2.GetResponse(
-                            valid=True,
-                            data=serialized[start:end],
-                            chunk_id=chunk_id,
-                            total_chunks=total_chunks,
-                            total_size=total_size,
-                        )
-                        chunk_resp = ray_client_pb2.DataResponse(
-                            get=get_resp, req_id=req_id
-                        )
-                        result_queue.put(chunk_resp)
                 except Exception as exc:
                     get_resp = ray_client_pb2.GetResponse(
                         valid=False, error=cloudpickle.dumps(exc)
                     )
-                    resp = ray_client_pb2.DataResponse(get=get_resp, req_id=req_id)
-                    result_queue.put(resp)
+                resp = ray_client_pb2.DataResponse(get=get_resp, req_id=req_id)
+                result_queue.put(resp)
 
             ref._on_completed(send_get_response)
             return None
@@ -423,14 +403,13 @@ def GetObject(self, request: ray_client_pb2.GetRequest, context):
         metadata = {k: v for k, v in context.invocation_metadata()}
         client_id = metadata.get("client_id")
         if client_id is None:
-            yield ray_client_pb2.GetResponse(
+            return ray_client_pb2.GetResponse(
                 valid=False,
                 error=cloudpickle.dumps(
                     ValueError("client_id is not specified in request metadata")
                 ),
             )
-        else:
-            yield from self._get_object(request, client_id)
+        return self._get_object(request, client_id)
 
     def _get_object(self, request: ray_client_pb2.GetRequest, client_id: str):
         objectrefs = []
@@ -439,7 +418,7 @@ def _get_object(self, request: ray_client_pb2.GetRequest, client_id: str):
             if ref:
                 objectrefs.append(ref)
             else:
-                yield ray_client_pb2.GetResponse(
+                return ray_client_pb2.GetResponse(
                     valid=False,
                     error=cloudpickle.dumps(
                         ValueError(
@@ -448,28 +427,14 @@ def _get_object(self, request: ray_client_pb2.GetRequest, client_id: str):
                         )
                     ),
                 )
-            return
         try:
             logger.debug("get: %s" % objectrefs)
             with disable_client_hook():
                 items = ray.get(objectrefs, timeout=request.timeout)
         except Exception as e:
-            yield ray_client_pb2.GetResponse(valid=False, error=cloudpickle.dumps(e))
-            return
+            return ray_client_pb2.GetResponse(valid=False, error=cloudpickle.dumps(e))
         serialized = dumps_from_server(items, client_id, self)
-        total_size = len(serialized)
-        assert total_size > 0, "Serialized object cannot be zero bytes"
-        total_chunks = math.ceil(total_size / OBJECT_TRANSFER_CHUNK_SIZE)
-        for chunk_id in range(request.start_chunk_id, total_chunks):
-            start = chunk_id * OBJECT_TRANSFER_CHUNK_SIZE
-            end = min(total_size, (chunk_id + 1) * OBJECT_TRANSFER_CHUNK_SIZE)
-            yield ray_client_pb2.GetResponse(
-                valid=True,
-                data=serialized[start:end],
-                chunk_id=chunk_id,
-                total_chunks=total_chunks,
-                total_size=total_size,
-            )
+        return ray_client_pb2.GetResponse(valid=True, data=serialized)
 
     def PutObject(
         self, request: ray_client_pb2.PutRequest, context=None
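Error propagation in server.py is unchanged in shape by this revert: failures are serialized into GetResponse.error with cloudpickle and rehydrated on the client (see the worker.py hunk below). A self-contained sketch of that round trip:

    import cloudpickle

    # Server side: capture the exception as an opaque blob.
    try:
        raise ValueError("object lost")
    except Exception as exc:
        blob = cloudpickle.dumps(exc)

    # Client side: rehydrate and re-raise.
    err = cloudpickle.loads(blob)
    assert isinstance(err, ValueError)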
diff --git a/python/ray/util/client/worker.py b/python/ray/util/client/worker.py
index 40e22189fbbf..7a3733578826 100644
--- a/python/ray/util/client/worker.py
+++ b/python/ray/util/client/worker.py
@@ -41,7 +41,6 @@
     GRPC_OPTIONS,
     GRPC_UNRECOVERABLE_ERRORS,
     INT32_MAX,
-    OBJECT_TRANSFER_WARNING_SIZE,
 )
 from ray.util.client.dataclient import DataClient
 from ray.util.client.logsclient import LogstreamClient
@@ -310,51 +309,6 @@ def _call_stub(self, stub_name: str, *args, **kwargs) -> Any:
                 continue
             raise ConnectionError("Client is shutting down.")
 
-    def _get_object_iterator(
-        self, req: ray_client_pb2.GetRequest, *args, **kwargs
-    ) -> Any:
-        """
-        Calls the stub for GetObject on the underlying server stub. If a
-        recoverable error occurs while streaming the response, attempts
-        to retry the get starting from the first chunk that hasn't been
-        received.
-        """
-        last_seen_chunk = -1
-        while not self._in_shutdown:
-            # If we disconnect partway through, restart the get request
-            # at the first chunk we haven't seen
-            req.start_chunk_id = last_seen_chunk + 1
-            try:
-                for chunk in self.server.GetObject(req, *args, **kwargs):
-                    if chunk.chunk_id <= last_seen_chunk:
-                        # Ignore repeat chunks
-                        logger.debug(
-                            f"Received a repeated chunk {chunk.chunk_id} "
-                            f"from request {req.req_id}."
-                        )
-                        continue
-                    if last_seen_chunk + 1 != chunk.chunk_id:
-                        raise RuntimeError(
-                            f"Received chunk {chunk.chunk_id} when we expected "
-                            f"{self.last_seen_chunk + 1}"
-                        )
-                    last_seen_chunk = chunk.chunk_id
-                    yield chunk
-                return
-            except grpc.RpcError as e:
-                if self._can_reconnect(e):
-                    time.sleep(0.5)
-                    continue
-                raise
-            except ValueError:
-                # Trying to use the stub on a cancelled channel will raise
-                # ValueError. This should only happen when the data client
-                # is attempting to reset the connection -- sleep and try
-                # again.
-                time.sleep(0.5)
-                continue
-        raise ConnectionError("Client is shutting down.")
-
     def _add_ids_to_metadata(self, metadata: Any):
         """
         Adds a unique req_id and the current thread's identifier to the
@@ -445,33 +399,18 @@ def get(self, vals, *, timeout: Optional[float] = None) -> Any:
 
     def _get(self, ref: List[ClientObjectRef], timeout: float):
         req = ray_client_pb2.GetRequest(ids=[r.id for r in ref], timeout=timeout)
-        data = bytearray()
         try:
-            resp = self._get_object_iterator(req, metadata=self.metadata)
-            for chunk in resp:
-                if not chunk.valid:
-                    try:
-                        err = cloudpickle.loads(chunk.error)
-                    except (pickle.UnpicklingError, TypeError):
-                        logger.exception("Failed to deserialize {}".format(chunk.error))
-                        raise
-                    raise err
-                if chunk.total_size > OBJECT_TRANSFER_WARNING_SIZE and log_once(
-                    "client_object_transfer_size_warning"
-                ):
-                    size_gb = chunk.total_size / 2 ** 30
-                    warnings.warn(
-                        "Ray Client is attempting to retrieve a "
-                        f"{size_gb:.2f} GiB object over the network, which may "
-                        "be slow. Consider serializing the object to a file "
-                        "and using S3 or rsync instead.",
-                        UserWarning,
-                        stacklevel=5,
-                    )
-                data.extend(chunk.data)
+            resp = self._call_stub("GetObject", req, metadata=self.metadata)
         except grpc.RpcError as e:
             raise decode_exception(e)
-        return loads_from_server(data)
+        if not resp.valid:
+            try:
+                err = cloudpickle.loads(resp.error)
+            except (pickle.UnpicklingError, TypeError):
+                logger.exception("Failed to deserialize {}".format(resp.error))
+                raise
+            raise err
+        return loads_from_server(resp.data)
 
     def put(self, val, *, client_ref_id: bytes = None):
         if isinstance(val, ClientObjectRef):
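One behavioral consequence of the worker.py revert above: `_call_stub` retries the whole unary GetObject call, so after a disconnect the object is re-sent from scratch rather than resumed at the last received chunk as the deleted `_get_object_iterator` did. A generic sketch of that whole-call retry shape (illustrative helper, not Ray's API):

    import time
    import grpc

    def call_with_retry(call, *args, attempts=5, backoff=0.5, **kwargs):
        # Retries the entire unary call; nothing is resumed mid-transfer.
        for i in range(attempts):
            try:
                return call(*args, **kwargs)
            except grpc.RpcError as e:
                if i == attempts - 1 or e.code() != grpc.StatusCode.UNAVAILABLE:
                    raise
                time.sleep(backoff)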
diff --git a/src/ray/protobuf/ray_client.proto b/src/ray/protobuf/ray_client.proto
index d94f85002d6d..e80f9d6a32f2 100644
--- a/src/ray/protobuf/ray_client.proto
+++ b/src/ray/protobuf/ray_client.proto
@@ -122,9 +122,6 @@ message GetRequest {
   float timeout = 2;
   // Whether to schedule this as a callback on the server side.
   bool asynchronous = 3;
-  // The chunk_id to start retrieving data from, in case the request is interrupted
-  // after partial retrieval by a disconnect
-  int32 start_chunk_id = 5;
 
   // Deprecated fields.
   bytes id = 1 [deprecated = true];
@@ -138,12 +135,6 @@ message GetResponse {
   bytes data = 2;
   // An error blob (for example, an exception) on failure.
   bytes error = 3;
-  // Identifies which chunk the data belongs to
-  int32 chunk_id = 4;
-  // Total number of chunks
-  int32 total_chunks = 5;
-  // Total size in bytes of the data being retrieved
-  uint64 total_size = 6;
 }
 
 // Waits for data to be ready on the server, with a timeout.
@@ -296,7 +287,8 @@ service RayletDriver {
   }
   rpc PrepRuntimeEnv(PrepRuntimeEnvRequest) returns (PrepRuntimeEnvResponse) {
   }
-  rpc GetObject(GetRequest) returns (stream GetResponse) {}
+  rpc GetObject(GetRequest) returns (GetResponse) {
+  }
   rpc PutObject(PutRequest) returns (PutResponse) {
   }
   rpc WaitObject(WaitRequest) returns (WaitResponse) {