Skip to content

Commit

Permalink
change all callbacks to move to save copies
Browse files Browse the repository at this point in the history
Signed-off-by: Ruiyang Wang <[email protected]>

fix cpp test

Signed-off-by: Ruiyang Wang <[email protected]>

more moves

Signed-off-by: Ruiyang Wang <[email protected]>

change all callbacks to move to save copies

Signed-off-by: Ruiyang Wang <[email protected]>

move in cython (unfortunately not 0 copy)

Signed-off-by: Ruiyang Wang <[email protected]>
  • Loading branch information
rynewang committed Aug 6, 2024
1 parent 492cc1b commit 3cd6f81
Show file tree
Hide file tree
Showing 22 changed files with 190 additions and 180 deletions.
60 changes: 27 additions & 33 deletions python/ray/includes/gcs_client.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ cdef class NewGcsClient:
status = self.inner.get().InternalKV().Get(
ns, key, timeout_ms, opt_value.value())
return raise_or_return(
convert_optional_str_none_for_not_found(status, opt_value))
convert_optional_str_none_for_not_found(status, move(opt_value)))

def internal_kv_multi_get(
self, keys: List[bytes], namespace=None, timeout=None
Expand All @@ -110,7 +110,7 @@ cdef class NewGcsClient:
with nogil:
status = self.inner.get().InternalKV().MultiGet(
ns, c_keys, timeout_ms, opt_values.value())
return raise_or_return(convert_optional_multi_get(status, opt_values))
return raise_or_return(convert_optional_multi_get(status, move(opt_values)))

def internal_kv_put(self, c_string key, c_string value, c_bool overwrite=False,
namespace=None, timeout=None) -> int:
Expand All @@ -125,7 +125,7 @@ cdef class NewGcsClient:
with nogil:
status = self.inner.get().InternalKV().Put(
ns, key, value, overwrite, timeout_ms, opt_added.value())
added = raise_or_return(convert_optional_bool(status, opt_added))
added = raise_or_return(convert_optional_bool(status, move(opt_added)))
return 1 if added else 0

def internal_kv_del(self, c_string key, c_bool del_by_prefix,
Expand All @@ -141,7 +141,7 @@ cdef class NewGcsClient:
with nogil:
status = self.inner.get().InternalKV().Del(
ns, key, del_by_prefix, timeout_ms, opt_num_deleted.value())
return raise_or_return(convert_optional_int(status, opt_num_deleted))
return raise_or_return(convert_optional_int(status, move(opt_num_deleted)))

def internal_kv_keys(
self, c_string prefix, namespace=None, timeout=None
Expand All @@ -154,7 +154,7 @@ cdef class NewGcsClient:
with nogil:
status = self.inner.get().InternalKV().Keys(
ns, prefix, timeout_ms, opt_keys.value())
return raise_or_return(convert_optional_vector_str(status, opt_keys))
return raise_or_return(convert_optional_vector_str(status, move(opt_keys)))

def internal_kv_exists(self, c_string key, namespace=None, timeout=None) -> bool:
cdef:
Expand All @@ -165,7 +165,7 @@ cdef class NewGcsClient:
with nogil:
status = self.inner.get().InternalKV().Exists(
ns, key, timeout_ms, opt_exists.value())
return raise_or_return(convert_optional_bool(status, opt_exists))
return raise_or_return(convert_optional_bool(status, move(opt_exists)))

#############################################################
# Internal KV async methods
Expand Down Expand Up @@ -678,7 +678,7 @@ cdef convert_status(CRayStatus status) with gil:
except Exception as e:
return None, e
cdef convert_optional_str_none_for_not_found(
CRayStatus status, const optional[c_string]& c_str) with gil:
CRayStatus status, optional[c_string]&& c_str) with gil:
# If status is NotFound, return None.
# If status is OK, return the value.
# Else, raise exception.
Expand All @@ -688,31 +688,29 @@ cdef convert_optional_str_none_for_not_found(
return None, None
check_status_timeout_as_rpc_error(status)
assert c_str.has_value()
return c_str.value(), None
return move(c_str.value()), None
except Exception as e:
return None, e

cdef convert_optional_multi_get(
CRayStatus status,
const optional[unordered_map[c_string, c_string]]& c_map) with gil:
optional[unordered_map[c_string, c_string]]&& c_map) with gil:
# -> Dict[str, str]
cdef unordered_map[c_string, c_string].const_iterator it
cdef unordered_map[c_string, c_string].iterator it
try:
check_status_timeout_as_rpc_error(status)
assert c_map.has_value()

result = {}
it = c_map.value().const_begin()
while it != c_map.value().const_end():
key = dereference(it).first
value = dereference(it).second
result[key] = value
it = c_map.value().begin()
while it != c_map.value().end():
result[dereference(it).first] = move(dereference(it).second)
postincrement(it)
return result, None
except Exception as e:
return None, e

cdef convert_optional_int(CRayStatus status, const optional[int]& c_int) with gil:
cdef convert_optional_int(CRayStatus status, optional[int]&& c_int) with gil:
# -> int
try:
check_status_timeout_as_rpc_error(status)
Expand All @@ -722,26 +720,15 @@ cdef convert_optional_int(CRayStatus status, const optional[int]& c_int) with gi
return None, e

cdef convert_optional_vector_str(
CRayStatus status, const optional[c_vector[c_string]]& c_vec) with gil:
# -> Dict[str, str]
cdef const c_vector[c_string]* vec
cdef c_vector[c_string].const_iterator it
CRayStatus status, optional[c_vector[c_string]]&& c_vec) with gil:
# -> List[bytes]
try:
check_status_timeout_as_rpc_error(status)

assert c_vec.has_value()
vec = &c_vec.value()
it = vec.const_begin()
result = []
while it != dereference(vec).const_end():
result.append(dereference(it))
postincrement(it)
return result, None
return convert_multi_str(status, move(c_vec.value()))
except Exception as e:
return None, e


cdef convert_optional_bool(CRayStatus status, const optional[c_bool]& b) with gil:
cdef convert_optional_bool(CRayStatus status, optional[c_bool]&& b) with gil:
# -> bool
try:
check_status_timeout_as_rpc_error(status)
Expand All @@ -758,10 +745,17 @@ cdef convert_multi_bool(CRayStatus status, c_vector[c_bool]&& c_data) with gil:
except Exception as e:
return None, e

cdef convert_multi_str(CRayStatus status, c_vector[c_string]&& c_data) with gil:
cdef convert_multi_str(CRayStatus status, c_vector[c_string]&& vec) with gil:
# -> List[bytes]
cdef c_vector[c_string].iterator it
try:
check_status_timeout_as_rpc_error(status)
return [datum for datum in c_data], None

it = vec.begin()
result = []
while it != vec.end():
result.append(move(dereference(it)))
postincrement(it)
return result, None
except Exception as e:
return None, e
3 changes: 2 additions & 1 deletion src/ray/core_worker/test/reference_count_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,8 @@ class MockDistributedPublisher : public pubsub::PublisherInterface {
if (it != subscription_callback_map_->end()) {
const auto callback_it = it->second.find(oid);
RAY_CHECK(callback_it != it->second.end());
callback_it->second(pub_message);
rpc::PubMessage copied = pub_message;
callback_it->second(std::move(copied));
}
}
}
Expand Down
7 changes: 3 additions & 4 deletions src/ray/core_worker/transport/actor_task_submitter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -176,9 +176,8 @@ void ActorTaskSubmitter::FailInflightTasks(
// network issue. We don't call `task_finisher_.FailOrRetryPendingTask` directly because
// there's much more work to do in the callback.
auto status = Status::IOError("Fail all inflight tasks due to actor state change.");
rpc::PushTaskReply reply;
for (const auto &[_, callback] : inflight_task_callbacks) {
callback(status, reply);
callback(status, rpc::PushTaskReply());
}
}

Expand Down Expand Up @@ -484,7 +483,7 @@ void ActorTaskSubmitter::PushActorTask(ClientQueue &queue,

queue.inflight_task_callbacks.emplace(task_id, std::move(reply_callback));
rpc::ClientCallback<rpc::PushTaskReply> wrapped_callback =
[this, task_id, actor_id](const Status &status, const rpc::PushTaskReply &reply) {
[this, task_id, actor_id](const Status &status, rpc::PushTaskReply &&reply) {
rpc::ClientCallback<rpc::PushTaskReply> reply_callback;
{
absl::MutexLock lock(&mu_);
Expand All @@ -500,7 +499,7 @@ void ActorTaskSubmitter::PushActorTask(ClientQueue &queue,
reply_callback = std::move(callback_it->second);
queue.inflight_task_callbacks.erase(callback_it);
}
reply_callback(status, reply);
reply_callback(status, std::move(reply));
};

task_finisher_.MarkTaskWaitingForExecution(task_id,
Expand Down
7 changes: 4 additions & 3 deletions src/ray/gcs/callback.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,10 @@ using StatusCallback = std::function<void(Status status)>;
/// \param status Status indicates whether the read was successful.
/// \param result The item returned by GCS. If the item to read doesn't exist,
/// this optional object is empty.
/// TODO(ryw): make an Either union type to avoid the optional.
template <typename Data>
using OptionalItemCallback =
std::function<void(Status status, const std::optional<Data> &result)>;
std::function<void(Status status, std::optional<Data> &&result)>;

/// This callback is used to receive multiple items from GCS when a read completes.
/// \param status Status indicates whether the read was successful.
Expand All @@ -48,12 +49,12 @@ using MultiItemCallback = std::function<void(Status status, std::vector<Data> &&
/// \param id The id of the item.
/// \param result The notification message.
template <typename ID, typename Data>
using SubscribeCallback = std::function<void(const ID &id, const Data &result)>;
using SubscribeCallback = std::function<void(const ID &id, Data &&result)>;

/// This callback is used to receive a single item from GCS.
/// \param result The item returned by GCS.
template <typename Data>
using ItemCallback = std::function<void(const Data &result)>;
using ItemCallback = std::function<void(Data &&result)>;

/// This callback is used to receive multiple key-value items from GCS.
/// \param result The key-value items returned by GCS.
Expand Down
Loading

0 comments on commit 3cd6f81

Please sign in to comment.