Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RPC] microtvm: fix RPC large transfer size issue #7838

Merged
merged 12 commits into from
Apr 22, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions apps/bundle_deploy/crt_config/crt_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,7 @@
/*! Size of the global function registry, in bytes. */
#define TVM_CRT_GLOBAL_FUNC_REGISTRY_SIZE_BYTES 200

/*! Maximum packet size, in bytes, including the length header. */
#define TVM_CRT_MAX_PACKET_SIZE_BYTES 512

#endif // TVM_RUNTIME_CRT_CONFIG_H_
26 changes: 23 additions & 3 deletions src/runtime/crt/common/crt_runtime_api.c
Original file line number Diff line number Diff line change
Expand Up @@ -298,8 +298,14 @@ static tvm_crt_error_t FindFunctionOrSetAPIError(tvm_module_index_t module_index
}

int TVMFuncGetGlobal(const char* name, TVMFunctionHandle* out) {
return FindFunctionOrSetAPIError(kGlobalFuncModuleIndex, &global_func_registry.registry, name,
out);
tvm_crt_error_t to_return =
FindFunctionOrSetAPIError(kGlobalFuncModuleIndex, &global_func_registry.registry, name, out);
// For compatibility with the C++ runtime equivalent, in src/runtime/registry.cc.
if (to_return == kTvmErrorFunctionNameNotFound) {
*out = NULL;
to_return = kTvmErrorNoError;
}
return to_return;
}

int TVMModGetFunction(TVMModuleHandle mod, const char* func_name, int query_imports,
Expand Down Expand Up @@ -343,7 +349,6 @@ int ModuleGetFunction(TVMValue* args, int* type_codes, int num_args, TVMValue* r
if (to_return == kTvmErrorFunctionNameNotFound) {
to_return = kTvmErrorNoError;
}

return to_return;
}

Expand Down Expand Up @@ -372,6 +377,17 @@ int TVMFuncFree(TVMFunctionHandle func) {

int RPCTimeEvaluator(TVMValue* args, int* type_codes, int num_args, TVMValue* ret_val,
int* ret_type_code);

// PackedFunc that reports the CRT max packet size, less the microtvm per-packet overhead.
// The incoming arguments are not read; the result is returned as a single integer.
int RPCGetCRTMaxPacketSize(TVMValue* args, int* type_codes, int num_args, TVMValue* ret_value,
                           int* ret_type_codes) {
  // microtvm overhead, field by field.
  const int packet_start_bytes = 2;
  const int length_bytes = 4;
  const int session_header_bytes = 3;
  const int crc_bytes = 2;
  // 11 bytes in total.
  const int overhead_bytes = packet_start_bytes + length_bytes + session_header_bytes + crc_bytes;

  ret_type_codes[0] = kTVMArgInt;
  ret_value[0].v_int64 = TVM_CRT_MAX_PACKET_SIZE_BYTES - overhead_bytes;
  return 0;
}

tvm_crt_error_t TVMInitializeRuntime() {
int idx = 0;
tvm_crt_error_t error = kTvmErrorNoError;
Expand Down Expand Up @@ -412,6 +428,10 @@ tvm_crt_error_t TVMInitializeRuntime() {
error = TVMFuncRegisterGlobal("runtime.RPCTimeEvaluator", &RPCTimeEvaluator, 0);
}

if (error == kTvmErrorNoError) {
error = TVMFuncRegisterGlobal("tvm.rpc.server.GetCRTMaxPacketSize", &RPCGetCRTMaxPacketSize, 0);
}

if (error != kTvmErrorNoError) {
TVMPlatformMemoryFree(registry_backing_memory, dev);
TVMPlatformMemoryFree(func_registry_memory, dev);
Expand Down
4 changes: 2 additions & 2 deletions src/runtime/crt/host/crt_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,10 @@
#define TVM_CRT_MAX_REGISTERED_MODULES 2

/*! Size of the global function registry, in bytes. */
#define TVM_CRT_GLOBAL_FUNC_REGISTRY_SIZE_BYTES 256
#define TVM_CRT_GLOBAL_FUNC_REGISTRY_SIZE_BYTES 512

/*! Maximum packet size, in bytes, including the length header. */
#define TVM_CRT_MAX_PACKET_SIZE_BYTES 64000
/* Parenthesized so the expansion stays a single operand inside larger expressions. */
#define TVM_CRT_MAX_PACKET_SIZE_BYTES (8 * 1024)

/*! \brief Maximum length of a PackedFunc function name. */
#define TVM_CRT_MAX_FUNCTION_NAME_LENGTH_BYTES 30
Expand Down
9 changes: 6 additions & 3 deletions src/runtime/crt/host/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -136,9 +136,12 @@ int main(int argc, char** argv) {
"failed to register GraphExecutor TVMModule");
#endif

if (TVMFuncRegisterGlobal("tvm.testing.reset_server", (TVMFunctionHandle)&testonly_reset_server,
0)) {
fprintf(stderr, "utvm runtime: internal error registering global packedfunc; exiting\n");
int error = TVMFuncRegisterGlobal("tvm.testing.reset_server",
(TVMFunctionHandle)&testonly_reset_server, 0);
if (error) {
fprintf(stderr,
"utvm runtime: internal error (error#: %x) registering global packedfunc; exiting\n",
error);
return 2;
}

Expand Down
3 changes: 3 additions & 0 deletions src/runtime/minrpc/rpc_reference.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ namespace runtime {
/*! \brief The current RPC protocol version. */
constexpr const char* kRPCProtocolVer = "0.8.0";

// Default maximum RPC transfer size, used when the tvm.rpc.server.GetCRTMaxPacketSize
// global function is not registered.
const uint64_t kRPCMaxTransferSizeBytesDefault = UINT64_MAX;

/*! \brief The RPC code */
enum class RPCCode : int {
kNone,
Expand Down
108 changes: 90 additions & 18 deletions src/runtime/rpc/rpc_endpoint.cc
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ class RPCEndpoint::EventHandler : public dmlc::Stream {
}

/*!
* \brief Recive incoming packed seq from the stream.
* \brief Receive incoming packed seq from the stream.
* \return The received arguments.
* \note The TVMArgs is available until we switch state.
*/
Expand Down Expand Up @@ -369,7 +369,6 @@ class RPCEndpoint::EventHandler : public dmlc::Stream {
*/
void HandleReturn(RPCCode code, RPCSession::FEncodeReturn setreturn) {
TVMArgs args = RecvPackedSeq();

if (code == RPCCode::kException) {
// switch to the state before sending exception.
this->SwitchToState(kRecvPacketNumBytes);
Expand Down Expand Up @@ -801,14 +800,13 @@ void RPCEndpoint::CopyToRemote(void* from_bytes, DLTensor* to, uint64_t nbytes)
std::lock_guard<std::mutex> lock(mutex_);
RPCCode code = RPCCode::kCopyToRemote;

uint64_t num_data_bytes = static_cast<uint64_t>(GetDataSize(*to));
ICHECK_EQ(nbytes, num_data_bytes);
uint64_t tensor_total_size_bytes = static_cast<uint64_t>(GetDataSize(*to));
ICHECK_LE(to->byte_offset + nbytes, tensor_total_size_bytes)
<< "CopyToRemote: overflow in tensor size: (byte_offset=" << to->byte_offset
<< ", nbytes=" << nbytes << ", tensor_total_size=" << tensor_total_size_bytes << ")";

uint64_t to_data = reinterpret_cast<uint64_t>(to->data);
uint64_t shape_bytes = to->ndim * sizeof(int64_t);
uint64_t packet_nbytes = sizeof(code) + sizeof(to_data) + sizeof(to->device) + sizeof(to->ndim) +
sizeof(to->dtype) + sizeof(to->byte_offset) + shape_bytes +
sizeof(nbytes) + num_data_bytes;
uint64_t overhead = RemoteCopyCalculatePacketOverheadSize(to, code, nbytes);
uint64_t packet_nbytes = overhead + nbytes;

handler_->Write(packet_nbytes);
handler_->Write(code);
Expand All @@ -822,14 +820,13 @@ void RPCEndpoint::CopyFromRemote(DLTensor* from, void* to_bytes, uint64_t nbytes
std::lock_guard<std::mutex> lock(mutex_);
RPCCode code = RPCCode::kCopyFromRemote;

uint64_t num_data_bytes = static_cast<uint64_t>(GetDataSize(*from));
CHECK_EQ(nbytes, num_data_bytes);
uint64_t tensor_total_size_bytes = static_cast<uint64_t>(GetDataSize(*from));
ICHECK_LE(from->byte_offset + nbytes, tensor_total_size_bytes)
<< "CopyFromRemote: overflow in tensor size: (byte_offset=" << from->byte_offset
<< ", nbytes=" << nbytes << ", tensor_total_size=" << tensor_total_size_bytes << ")";

uint64_t from_data = reinterpret_cast<uint64_t>(from->data);
uint64_t shape_bytes = from->ndim * sizeof(int64_t);
uint64_t packet_nbytes = sizeof(code) + sizeof(from_data) + sizeof(from->device) +
sizeof(from->ndim) + sizeof(from->dtype) + sizeof(from->byte_offset) +
shape_bytes + sizeof(nbytes);
uint64_t overhead = RemoteCopyCalculatePacketOverheadSize(from, code, nbytes);
uint64_t packet_nbytes = overhead;

handler_->Write(packet_nbytes);
handler_->Write(code);
Expand Down Expand Up @@ -981,11 +978,55 @@ class RPCClientSession : public RPCSession, public DeviceAPI {
}

void CopyToRemote(void* local_from_bytes, DLTensor* remote_to, uint64_t nbytes) final {
endpoint_->CopyToRemote(local_from_bytes, remote_to, nbytes);
RPCCode code = RPCCode::kCopyToRemote;
uint64_t overhead = RemoteCopyCalculatePacketOverheadSize(remote_to, code, nbytes);
uint64_t rpc_max_size = GetRPCMaxTransferSize();
ICHECK_GT(rpc_max_size, overhead) << "CopyToRemote: Invalid block size!";
const uint64_t block_size = rpc_max_size - overhead;
uint64_t block_count = 0;
const uint64_t num_blocks = nbytes / block_size;
void* from_bytes;

for (block_count = 0; block_count < num_blocks; block_count++) {
remote_to->byte_offset = block_count * block_size;
from_bytes = reinterpret_cast<void*>(
(reinterpret_cast<uint8_t*>(local_from_bytes) + block_count * block_size));
endpoint_->CopyToRemote(from_bytes, remote_to, block_size);
}

const uint64_t remainder_bytes = nbytes % block_size;
if (remainder_bytes != 0) {
remote_to->byte_offset = block_count * block_size;
from_bytes = reinterpret_cast<void*>(
(reinterpret_cast<uint8_t*>(local_from_bytes) + block_count * block_size));
endpoint_->CopyToRemote(from_bytes, remote_to, remainder_bytes);
}
}

void CopyFromRemote(DLTensor* remote_from, void* local_to_bytes, uint64_t nbytes) final {
endpoint_->CopyFromRemote(remote_from, local_to_bytes, nbytes);
RPCCode code = RPCCode::kCopyFromRemote;
uint64_t overhead = RemoteCopyCalculatePacketOverheadSize(remote_from, code, nbytes);
uint64_t rpc_max_size = GetRPCMaxTransferSize();
ICHECK_GT(rpc_max_size, overhead) << "CopyFromRemote: Invalid block size!";
const uint64_t block_size = rpc_max_size - overhead;
uint64_t block_count = 0;
const uint64_t num_blocks = nbytes / block_size;
void* to_bytes;

for (block_count = 0; block_count < num_blocks; block_count++) {
remote_from->byte_offset = block_count * block_size;
to_bytes = reinterpret_cast<void*>(
(reinterpret_cast<uint8_t*>(local_to_bytes) + block_count * block_size));
endpoint_->CopyFromRemote(remote_from, to_bytes, block_size);
}

const uint64_t remainder_bytes = nbytes % block_size;
if (remainder_bytes != 0) {
remote_from->byte_offset = block_count * block_size;
to_bytes = reinterpret_cast<void*>(
(reinterpret_cast<uint8_t*>(local_to_bytes) + block_count * block_size));
endpoint_->CopyFromRemote(remote_from, to_bytes, remainder_bytes);
}
}

void FreeHandle(void* handle, int type_code) final {
Expand Down Expand Up @@ -1042,12 +1083,43 @@ class RPCClientSession : public RPCSession, public DeviceAPI {
bool IsLocalSession() const final { return false; }

private:
uint64_t GetRPCMaxTransferSize() {
if (rpc_chunk_max_size_bytes_ > 0) {
return (uint64_t)rpc_chunk_max_size_bytes_;
}

PackedFuncHandle rpc_func = GetFunction("tvm.rpc.server.GetCRTMaxPacketSize");
mehrdadh marked this conversation as resolved.
Show resolved Hide resolved
if (rpc_func == nullptr) {
rpc_chunk_max_size_bytes_ = (int64_t)kRPCMaxTransferSizeBytesDefault;
} else {
CallFunc(rpc_func, nullptr, nullptr, 0, [this](TVMArgs args) {
// Use args[1] as return value, args[0] is tcode
// Look at RPCWrappedFunc in src/runtime/rpc/rpc_module.cc
rpc_chunk_max_size_bytes_ = (int64_t)args[1];
ICHECK_GT(rpc_chunk_max_size_bytes_, 0)
<< "RPC max transfer size is <= 0! (remote value = " << rpc_chunk_max_size_bytes_
<< ")";
});
}
return (uint64_t)rpc_chunk_max_size_bytes_;
}

std::shared_ptr<RPCEndpoint> endpoint_;
int64_t rpc_chunk_max_size_bytes_ = -1;
};

/*!
 * \brief Builds an RPCClientSession around the given endpoint.
 * \param endpoint The endpoint handed to the new session.
 * \return The new session, typed as its RPCSession base.
 */
std::shared_ptr<RPCSession> CreateClientSession(std::shared_ptr<RPCEndpoint> endpoint) {
  std::shared_ptr<RPCSession> session = std::make_shared<RPCClientSession>(endpoint);
  return session;
}

/*!
 * \brief Adds up the sizes of the non-payload fields of a remote-copy packet.
 * \param tensor Tensor being copied; its ndim sets the size of the shape field.
 * \param code RPC code of the transfer (only its size is counted).
 * \param nbytes Number of bytes to transfer (only the size of the field is counted).
 * \return The overhead size in bytes.
 */
uint64_t RemoteCopyCalculatePacketOverheadSize(DLTensor* tensor, RPCCode code, uint64_t nbytes) {
  // One int64_t per tensor dimension.
  const uint64_t shape_field_bytes = tensor->ndim * sizeof(int64_t);
  // The data pointer is counted as a 64-bit integer.
  const uint64_t data_handle_bytes = sizeof(uint64_t);

  uint64_t total = sizeof(code);
  total += data_handle_bytes;
  total += sizeof(tensor->device);
  total += sizeof(tensor->ndim);
  total += sizeof(tensor->dtype);
  total += sizeof(tensor->byte_offset);
  total += shape_field_bytes;
  total += sizeof(nbytes);
  return total;
}

} // namespace runtime
} // namespace tvm
10 changes: 10 additions & 0 deletions src/runtime/rpc/rpc_endpoint.h
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,16 @@ template <typename... Args>
inline TVMRetValue RPCEndpoint::SysCallRemote(RPCCode code, Args&&... args) {
return syscall_remote_(static_cast<int>(code), std::forward<Args>(args)...);
}

/*!
* \brief Calculates the overhead (non-payload) size of a remote-copy packet.
*        Used for both CopyToRemote and CopyFromRemote.
* \param tensor DLTensor to copy; its ndim determines the size of the shape field.
* \param code RPCCode for this transfer.
* \param nbytes Number of bytes to transfer.
* \return The remote-copy packet overhead size, in bytes.
*/
uint64_t RemoteCopyCalculatePacketOverheadSize(DLTensor* tensor, RPCCode code, uint64_t nbytes);

} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_RPC_RPC_ENDPOINT_H_