Revert "[Client] chunked get requests" #22455

Merged 1 commit on Feb 17, 2022
81 changes: 4 additions & 77 deletions python/ray/tests/test_client_reconnect.py
@@ -1,11 +1,9 @@
from concurrent import futures
-import asyncio
import contextlib
import os
import threading
import sys
import grpc
-import numpy as np

import time
import random
@@ -127,8 +125,7 @@ def _call_inner_function(
            context.set_code(e.code())
            context.set_details(e.details())
            raise
-        if self.on_response and method != "GetObject":
-            # GetObject streams response, handle on_response separately
+        if self.on_response:
            self.on_response(response)
        return response

@@ -162,10 +159,7 @@ def Terminate(self, req, context=None):
        return self._call_inner_function(req, context, "Terminate")

    def GetObject(self, request, context=None):
-        for response in self._call_inner_function(request, context, "GetObject"):
-            if self.on_response:
-                self.on_response(response)
-            yield response
+        return self._call_inner_function(request, context, "GetObject")

    def PutObject(
        self, request: ray_client_pb2.PutRequest, context=None
@@ -276,8 +270,8 @@ def start_middleman_server(
        real_addr="localhost:50051",
        on_log_response=on_log_response,
        on_data_response=on_data_response,
-        on_task_request=on_task_request,
-        on_task_response=on_task_response,
+        on_task_request=on_task_response,
+        on_task_response=on_task_request,
    )
    middleman.start()
    ray.init("ray://localhost:10011")
@@ -325,73 +319,6 @@ def disconnect(middleman):
    disconnect_thread.join()


-def test_disconnects_during_large_get():
-    """
-    Disconnect repeatedly during a large (multi-chunk) get.
-    """
-    i = 0
-    started = False
-
-    def fail_every_three(_):
-        # Inject an error every third time this method is called
-        nonlocal i, started
-        if not started:
-            return
-        i += 1
-        if i % 3 == 0:
-            raise RuntimeError
-
-    @ray.remote
-    def large_result():
-        # 1024x1024x128 float64 matrix (1024 MiB). With 64MiB chunk size,
-        # it will take at least 16 chunks to transfer this object. Since
-        # the failure is injected every 3 chunks, this transfer can only
-        # work if the chunked get request retries at the last received chunk
-        # (instead of starting from the beginning each retry)
-        return np.random.random((1024, 1024, 128))
-
-    with start_middleman_server(on_task_response=fail_every_three):
-        started = True
-        result = ray.get(large_result.remote())
-        assert result.shape == (1024, 1024, 128)
-
-
-def test_disconnects_during_large_async_get():
-    """
-    Disconnect repeatedly during a large (multi-chunk) async get.
-    """
-    i = 0
-    started = False
-
-    def fail_every_three(_):
-        # Inject an error every third time this method is called
-        nonlocal i, started
-        if not started:
-            return
-        i += 1
-        if i % 3 == 0:
-            raise RuntimeError
-
-    @ray.remote
-    def large_result():
-        # 1024x1024x128 float64 matrix (1024 MiB). With 64MiB chunk size,
-        # it will take at least 16 chunks to transfer this object. Since
-        # the failure is injected every 3 chunks, this transfer can only
-        # work if the chunked get request retries at the last received chunk
-        # (instead of starting from the beginning each retry)
-        return np.random.random((1024, 1024, 128))
-
-    with start_middleman_server(on_data_response=fail_every_three):
-        started = True
-
-        async def get_large_result():
-            return await large_result.remote()
-
-        loop = asyncio.get_event_loop()
-        result = loop.run_until_complete(get_large_result())
-        assert result.shape == (1024, 1024, 128)
-
-
def test_valid_actor_state():
    """
    Repeatedly inject errors in the middle of mutating actor calls. Check
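The deleted tests lean on size arithmetic that is only stated in their comments: a (1024, 1024, 128) float64 array is 1 GiB, which at 64 MiB per chunk means 16 chunks, while the middleman injects a failure on every third message. A quick back-of-the-envelope check (plain Python, not part of the test suite):

# Size arithmetic behind the deleted large-get tests.
shape = (1024, 1024, 128)
dtype_bytes = 8  # float64
chunk_size = 64 * 2 ** 20  # 64 MiB, the chunk size the reverted code used

total_bytes = shape[0] * shape[1] * shape[2] * dtype_bytes
num_chunks = -(-total_bytes // chunk_size)  # ceiling division

assert total_bytes == 1024 * 2 ** 20  # 1024 MiB
assert num_chunks == 16

Since a failure lands every three messages and 16 chunks are needed, a client that restarted the transfer from chunk 0 on every retry could never finish; the tests only pass when the get resumes from the last received chunk.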
2 changes: 1 addition & 1 deletion python/ray/util/client/__init__.py
@@ -16,7 +16,7 @@

# This version string is incremented to indicate breaking changes in the
# protocol that require upgrading the client version.
-CURRENT_PROTOCOL_VERSION = "2022-02-14"
+CURRENT_PROTOCOL_VERSION = "2021-12-07"


class _ClientContext:
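CURRENT_PROTOCOL_VERSION gates client/server compatibility, and chunked gets were a breaking protocol change, so the revert also rolls the version string back to the pre-chunking value. A minimal sketch of the kind of mismatch check such a version string enables (hypothetical code for illustration, not Ray's actual handshake):

CURRENT_PROTOCOL_VERSION = "2021-12-07"

def check_protocol_version(server_version: str) -> None:
    # Hypothetical check: refuse to connect when the strings differ.
    if server_version != CURRENT_PROTOCOL_VERSION:
        raise RuntimeError(
            f"Client protocol {CURRENT_PROTOCOL_VERSION!r} does not match "
            f"server protocol {server_version!r}; upgrade the client or server."
        )

check_protocol_version("2021-12-07")    # compatible
# check_protocol_version("2022-02-14")  # would raise after this revert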
8 changes: 0 additions & 8 deletions python/ray/util/client/common.py
@@ -84,12 +84,6 @@

CLIENT_SERVER_MAX_THREADS = float(os.getenv("RAY_CLIENT_SERVER_MAX_THREADS", 100))

-# Large objects are chunked into 64 MiB messages
-OBJECT_TRANSFER_CHUNK_SIZE = 64 * 2 ** 20
-
-# Warn the user if the object being transferred is larger than 2 GiB
-OBJECT_TRANSFER_WARNING_SIZE = 2 * 2 ** 30
-

class ClientObjectRef(raylet.ObjectRef):
    def __init__(self, id: Union[bytes, Future]):
@@ -172,8 +166,6 @@ def deserialize_obj(

            if isinstance(resp, Exception):
                data = resp
-            elif isinstance(resp, bytearray):
-                data = loads_from_server(resp)
            else:
                obj = resp.get
                data = None
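The two removed constants set the wire-level chunk size (64 MiB) and a warning threshold (2 GiB) for large transfers. A minimal sketch of how a serialized payload would be split at that chunk size, assuming these constants (illustrative only, not the Ray server implementation):

from typing import Iterator

OBJECT_TRANSFER_CHUNK_SIZE = 64 * 2 ** 20   # 64 MiB, as in the removed constant
OBJECT_TRANSFER_WARNING_SIZE = 2 * 2 ** 30  # 2 GiB, as in the removed constant

def chunk_payload(payload: bytes) -> Iterator[bytes]:
    # Yield successive 64 MiB slices; the final slice may be shorter.
    if len(payload) > OBJECT_TRANSFER_WARNING_SIZE:
        print("warning: transferring more than 2 GiB over the network")
    for start in range(0, len(payload), OBJECT_TRANSFER_CHUNK_SIZE):
        yield payload[start:start + OBJECT_TRANSFER_CHUNK_SIZE]

# A 200 MiB payload becomes 4 chunks: three full 64 MiB chunks plus an 8 MiB tail.
assert len(list(chunk_payload(bytes(200 * 2 ** 20)))) == 4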
105 changes: 10 additions & 95 deletions python/ray/util/client/dataclient.py
@@ -4,16 +4,14 @@
import logging
import queue
import threading
-import warnings
import grpc

from collections import OrderedDict
from typing import Any, Callable, Dict, TYPE_CHECKING, Optional, Union

import ray.core.generated.ray_client_pb2 as ray_client_pb2
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
-from ray.util.client.common import INT32_MAX, OBJECT_TRANSFER_WARNING_SIZE
-from ray.util.debug import log_once
+from ray.util.client.common import INT32_MAX

if TYPE_CHECKING:
    from ray.util.client.worker import Worker
@@ -26,83 +24,6 @@
ACKNOWLEDGE_BATCH_SIZE = 32


-class ChunkCollector:
-    """
-    This object collects chunks from async get requests via __call__, and
-    calls the underlying callback when the object is fully received, or if an
-    exception while retrieving the object occurs.
-
-    This is not used in synchronous gets (synchronous gets interact with the
-    raylet servicer directly, not through the datapath).
-
-    __call__ returns true once the underlying call back has been called.
-    """
-
-    def __init__(self, callback: ResponseCallable, request: ray_client_pb2.DataRequest):
-        # Bytearray containing data received so far
-        self.data = bytearray()
-        # The callback that will be called once all data is received
-        self.callback = callback
-        # The id of the last chunk we've received, or -1 if haven't seen any yet
-        self.last_seen_chunk = -1
-        # The GetRequest that initiated the transfer. start_chunk_id will be
-        # updated as chunks are received to avoid re-requesting chunks that
-        # we've already received.
-        self.request = request
-
-    def __call__(self, response: Union[ray_client_pb2.DataResponse, Exception]) -> bool:
-        if isinstance(response, Exception):
-            self.callback(response)
-            return True
-        get_resp = response.get
-        if not get_resp.valid:
-            self.callback(response)
-            return True
-        if get_resp.total_size > OBJECT_TRANSFER_WARNING_SIZE and log_once(
-            "client_object_transfer_size_warning"
-        ):
-            size_gb = get_resp.total_size / 2 ** 30
-            warnings.warn(
-                "Ray Client is attempting to retrieve a "
-                f"{size_gb:.2f} GiB object over the network, which may "
-                "be slow. Consider serializing the object to a file and "
-                "using rsync or S3 instead.",
-                UserWarning,
-            )
-        chunk_data = get_resp.data
-        chunk_id = get_resp.chunk_id
-        if chunk_id == self.last_seen_chunk + 1:
-            self.data.extend(chunk_data)
-            self.last_seen_chunk = chunk_id
-            # If we disconnect partway through, restart the get request
-            # at the first chunk we haven't seen
-            self.request.get.start_chunk_id = self.last_seen_chunk + 1
-        elif chunk_id > self.last_seen_chunk + 1:
-            # A chunk was skipped. This shouldn't happen in practice since
-            # grpc guarantees that chunks will arrive in order.
-            msg = (
-                f"Received chunk {chunk_id} when we expected "
-                f"{self.last_seen_chunk + 1} for request {response.req_id}"
-            )
-            logger.warning(msg)
-            self.callback(RuntimeError(msg))
-            return True
-        else:
-            # We received a chunk that've already seen before. Ignore, since
-            # it should already be appended to self.data.
-            logger.debug(
-                f"Received a repeated chunk {chunk_id} "
-                f"from request {response.req_id}."
-            )
-
-        if get_resp.chunk_id == get_resp.total_chunks - 1:
-            self.callback(self.data)
-            return True
-        else:
-            # Not done yet
-            return False
-
-
class DataClient:
    def __init__(self, client_worker: "Worker", client_id: str, metadata: list):
        """Initializes a thread-safe datapath over a Ray Client gRPC channel.
@@ -198,25 +119,20 @@ def _process_response(self, response: Any) -> None:
            logger.debug(f"Got unawaited response {response}")
            return
        if response.req_id in self.asyncio_waiting_data:
-            can_remove = True
            try:
-                callback = self.asyncio_waiting_data[response.req_id]
-                if isinstance(callback, ChunkCollector):
-                    can_remove = callback(response)
-                elif callback:
+                # NOTE: calling self.asyncio_waiting_data.pop() results
+                # in the destructor of ClientObjectRef running, which
+                # calls ReleaseObject(). So self.asyncio_waiting_data
+                # is accessed without holding self.lock. Holding the
+                # lock shouldn't be necessary either.
+                callback = self.asyncio_waiting_data.pop(response.req_id)
+                if callback:
                    callback(response)
-                if can_remove:
-                    # NOTE: calling del self.asyncio_waiting_data results
-                    # in the destructor of ClientObjectRef running, which
-                    # calls ReleaseObject(). So self.asyncio_waiting_data
-                    # is accessed without holding self.lock. Holding the
-                    # lock shouldn't be necessary either.
-                    del self.asyncio_waiting_data[response.req_id]
            except Exception:
                logger.exception("Callback error:")
            with self.lock:
                # Update outstanding requests
-                if response.req_id in self.outstanding_requests and can_remove:
+                if response.req_id in self.outstanding_requests:
                    del self.outstanding_requests[response.req_id]
                # Acknowledge response
                self._acknowledge(response.req_id)
@@ -454,8 +370,7 @@ def RegisterGetCallback(
        datareq = ray_client_pb2.DataRequest(
            get=request,
        )
-        collector = ChunkCollector(callback=callback, request=datareq)
-        self._async_send(datareq, collector)
+        self._async_send(datareq, callback)

    # TODO: convert PutObject to async
    def PutObject(
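The removed ChunkCollector reassembles a streamed object and, crucially, advances the pending GetRequest's start_chunk_id as each chunk arrives, so a retry after a disconnect resumes at the first missing chunk instead of chunk 0. A distilled, self-contained sketch of that resume loop (the names below are illustrative, not Ray's API):

CHUNKS = [b"a" * 4, b"b" * 4, b"c" * 4, b"d" * 4]  # stand-in for a 4-chunk object
calls = 0

def flaky_stream(start_chunk_id: int):
    # Simulated server: streams chunks starting at start_chunk_id, but drops
    # the connection on every third message, mirroring fail_every_three in
    # the deleted tests.
    global calls
    for chunk_id in range(start_chunk_id, len(CHUNKS)):
        calls += 1
        if calls % 3 == 0:
            raise ConnectionError("disconnected mid-transfer")
        yield chunk_id, CHUNKS[chunk_id]

def get_with_resume() -> bytes:
    data = bytearray()
    start_chunk_id = 0
    while True:
        try:
            for chunk_id, chunk in flaky_stream(start_chunk_id):
                data.extend(chunk)
                # Advance the resume point so the next retry starts at the
                # first chunk we have not seen, not at chunk 0.
                start_chunk_id = chunk_id + 1
            return bytes(data)
        except ConnectionError:
            continue

assert get_with_resume() == b"".join(CHUNKS)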
2 changes: 1 addition & 1 deletion python/ray/util/client/server/proxier.py
@@ -530,7 +530,7 @@ def Terminate(self, req, context=None):
        return self._call_inner_function(req, context, "Terminate")

    def GetObject(self, request, context=None):
-        yield from self._call_inner_function(request, context, "GetObject")
+        return self._call_inner_function(request, context, "GetObject")

    def PutObject(
        self, request: ray_client_pb2.PutRequest, context=None
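The one-line proxier change swaps generator delegation (yield from, used while GetObject was a server-streaming call that yielded chunked responses) back to a plain return of the inner call's single result. A pure-Python sketch of the difference between the two statements, outside of any gRPC machinery:

def inner_streaming():
    # Stand-in for a server-streaming handler: yields several responses.
    yield "chunk-0"
    yield "chunk-1"

def inner_unary():
    # Stand-in for a unary handler: returns one response object.
    return "whole-object"

def proxy_streaming():
    # yield from turns the proxy into a generator and forwards every item.
    yield from inner_streaming()

def proxy_unary():
    # return simply hands back whatever the inner call produced.
    return inner_unary()

assert list(proxy_streaming()) == ["chunk-0", "chunk-1"]
assert proxy_unary() == "whole-object"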