Skip to content

Commit

Permalink
[RPC] microtvm: fix RPC large transfer size issue (apache#7838)
Browse files Browse the repository at this point in the history
* fix rpc for microtvm

* apply feedbacks

* bundle deploy fix

* fix func registry size

* mv constant

* fix copyfromremote

* address comments and fix error

* change rpc default max size

* Trigger Build

* add checks

* Trigger Build

* fix ICHECK
  • Loading branch information
mehrdadh authored and Trevor Morris committed May 6, 2021
1 parent da710a4 commit 90c862c
Show file tree
Hide file tree
Showing 7 changed files with 137 additions and 26 deletions.
3 changes: 3 additions & 0 deletions apps/bundle_deploy/crt_config/crt_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,7 @@
/*! Size of the global function registry, in bytes. */
#define TVM_CRT_GLOBAL_FUNC_REGISTRY_SIZE_BYTES 200

/*! Maximum packet size, in bytes, including the length header. */
#define TVM_CRT_MAX_PACKET_SIZE_BYTES 512

#endif // TVM_RUNTIME_CRT_CONFIG_H_
26 changes: 23 additions & 3 deletions src/runtime/crt/common/crt_runtime_api.c
Original file line number Diff line number Diff line change
Expand Up @@ -307,8 +307,14 @@ static tvm_crt_error_t FindFunctionOrSetAPIError(tvm_module_index_t module_index
}

int TVMFuncGetGlobal(const char* name, TVMFunctionHandle* out) {
return FindFunctionOrSetAPIError(kGlobalFuncModuleIndex, &global_func_registry.registry, name,
out);
// Look up `name` in the global function registry; the result handle is written to *out.
tvm_crt_error_t to_return =
FindFunctionOrSetAPIError(kGlobalFuncModuleIndex, &global_func_registry.registry, name, out);
// For compatibility with the C++ runtime equivalent, in src/runtime/registry.cc.
// A name that is not found is not reported as an error: *out is set to NULL and
// kTvmErrorNoError is returned, so callers must check *out for NULL.
if (to_return == kTvmErrorFunctionNameNotFound) {
*out = NULL;
to_return = kTvmErrorNoError;
}
// Any other error code from the lookup is passed through unchanged.
return to_return;
}

int TVMModGetFunction(TVMModuleHandle mod, const char* func_name, int query_imports,
Expand Down Expand Up @@ -352,7 +358,6 @@ int ModuleGetFunction(TVMValue* args, int* type_codes, int num_args, TVMValue* r
if (to_return == kTvmErrorFunctionNameNotFound) {
to_return = kTvmErrorNoError;
}

return to_return;
}

Expand Down Expand Up @@ -381,6 +386,17 @@ int TVMFuncFree(TVMFunctionHandle func) {

int RPCTimeEvaluator(TVMValue* args, int* type_codes, int num_args, TVMValue* ret_val,
int* ret_type_code);

// PackedFunc that returns, as an int64, TVM_CRT_MAX_PACKET_SIZE_BYTES minus the
// microTVM framing overhead. The incoming arguments are not inspected.
int RPCGetCRTMaxPacketSize(TVMValue* args, int* type_codes, int num_args, TVMValue* ret_value,
                           int* ret_type_codes) {
  // microTVM overhead per packet, in bytes:
  //   packet start(2) + length(4) + session header(3) + crc(2) = 11
  enum { kMicroTvmOverheadBytes = 2 + 4 + 3 + 2 };

  ret_value->v_int64 = TVM_CRT_MAX_PACKET_SIZE_BYTES - kMicroTvmOverheadBytes;
  *ret_type_codes = kTVMArgInt;
  return 0;
}

tvm_crt_error_t TVMInitializeRuntime() {
int idx = 0;
tvm_crt_error_t error = kTvmErrorNoError;
Expand Down Expand Up @@ -421,6 +437,10 @@ tvm_crt_error_t TVMInitializeRuntime() {
error = TVMFuncRegisterGlobal("runtime.RPCTimeEvaluator", &RPCTimeEvaluator, 0);
}

if (error == kTvmErrorNoError) {
error = TVMFuncRegisterGlobal("tvm.rpc.server.GetCRTMaxPacketSize", &RPCGetCRTMaxPacketSize, 0);
}

if (error != kTvmErrorNoError) {
TVMPlatformMemoryFree(registry_backing_memory, dev);
TVMPlatformMemoryFree(func_registry_memory, dev);
Expand Down
4 changes: 2 additions & 2 deletions src/runtime/crt/host/crt_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,10 @@
#define TVM_CRT_MAX_REGISTERED_MODULES 2

/*! Size of the global function registry, in bytes. */
#define TVM_CRT_GLOBAL_FUNC_REGISTRY_SIZE_BYTES 256
#define TVM_CRT_GLOBAL_FUNC_REGISTRY_SIZE_BYTES 512

/*! Maximum packet size, in bytes, including the length header. */
#define TVM_CRT_MAX_PACKET_SIZE_BYTES 64000
#define TVM_CRT_MAX_PACKET_SIZE_BYTES (8 * 1024)  // parenthesized so it expands as one value

/*! \brief Maximum length of a PackedFunc function name. */
#define TVM_CRT_MAX_FUNCTION_NAME_LENGTH_BYTES 30
Expand Down
9 changes: 6 additions & 3 deletions src/runtime/crt/host/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -136,9 +136,12 @@ int main(int argc, char** argv) {
"failed to register GraphExecutor TVMModule");
#endif

if (TVMFuncRegisterGlobal("tvm.testing.reset_server", (TVMFunctionHandle)&testonly_reset_server,
0)) {
fprintf(stderr, "utvm runtime: internal error registering global packedfunc; exiting\n");
int error = TVMFuncRegisterGlobal("tvm.testing.reset_server",
(TVMFunctionHandle)&testonly_reset_server, 0);
if (error) {
fprintf(stderr,
"utvm runtime: internal error (error#: %x) registering global packedfunc; exiting\n",
error);
return 2;
}

Expand Down
3 changes: 3 additions & 0 deletions src/runtime/minrpc/rpc_reference.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ namespace runtime {
/*! \brief The current RPC protocol version. */
constexpr const char* kRPCProtocolVer = "0.8.0";

/*!
 * \brief Maximum RPC transfer size, in bytes, used when the
 *  tvm.rpc.server.GetCRTMaxPacketSize global function is not registered.
 */
constexpr uint64_t kRPCMaxTransferSizeBytesDefault = UINT64_MAX;

/*! \brief The RPC code */
enum class RPCCode : int {
kNone,
Expand Down
108 changes: 90 additions & 18 deletions src/runtime/rpc/rpc_endpoint.cc
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ class RPCEndpoint::EventHandler : public dmlc::Stream {
}

/*!
* \brief Recive incoming packed seq from the stream.
* \brief Receive incoming packed seq from the stream.
* \return The received arguments.
* \note The TVMArgs is available until we switchstate.
*/
Expand Down Expand Up @@ -369,7 +369,6 @@ class RPCEndpoint::EventHandler : public dmlc::Stream {
*/
void HandleReturn(RPCCode code, RPCSession::FEncodeReturn setreturn) {
TVMArgs args = RecvPackedSeq();

if (code == RPCCode::kException) {
// switch to the state before sending exception.
this->SwitchToState(kRecvPacketNumBytes);
Expand Down Expand Up @@ -802,14 +801,13 @@ void RPCEndpoint::CopyToRemote(void* from_bytes, DLTensor* to, uint64_t nbytes)
std::lock_guard<std::mutex> lock(mutex_);
RPCCode code = RPCCode::kCopyToRemote;

uint64_t num_data_bytes = static_cast<uint64_t>(GetDataSize(*to));
ICHECK_EQ(nbytes, num_data_bytes);
uint64_t tensor_total_size_bytes = static_cast<uint64_t>(GetDataSize(*to));
ICHECK_LE(to->byte_offset + nbytes, tensor_total_size_bytes)
<< "CopyToRemote: overflow in tensor size: (byte_offset=" << to->byte_offset
<< ", nbytes=" << nbytes << ", tensor_total_size=" << tensor_total_size_bytes << ")";

uint64_t to_data = reinterpret_cast<uint64_t>(to->data);
uint64_t shape_bytes = to->ndim * sizeof(int64_t);
uint64_t packet_nbytes = sizeof(code) + sizeof(to_data) + sizeof(to->device) + sizeof(to->ndim) +
sizeof(to->dtype) + sizeof(to->byte_offset) + shape_bytes +
sizeof(nbytes) + num_data_bytes;
uint64_t overhead = RemoteCopyCalculatePacketOverheadSize(to, code, nbytes);
uint64_t packet_nbytes = overhead + nbytes;

handler_->Write(packet_nbytes);
handler_->Write(code);
Expand All @@ -823,14 +821,13 @@ void RPCEndpoint::CopyFromRemote(DLTensor* from, void* to_bytes, uint64_t nbytes
std::lock_guard<std::mutex> lock(mutex_);
RPCCode code = RPCCode::kCopyFromRemote;

uint64_t num_data_bytes = static_cast<uint64_t>(GetDataSize(*from));
CHECK_EQ(nbytes, num_data_bytes);
uint64_t tensor_total_size_bytes = static_cast<uint64_t>(GetDataSize(*from));
ICHECK_LE(from->byte_offset + nbytes, tensor_total_size_bytes)
<< "CopyFromRemote: overflow in tensor size: (byte_offset=" << from->byte_offset
<< ", nbytes=" << nbytes << ", tensor_total_size=" << tensor_total_size_bytes << ")";

uint64_t from_data = reinterpret_cast<uint64_t>(from->data);
uint64_t shape_bytes = from->ndim * sizeof(int64_t);
uint64_t packet_nbytes = sizeof(code) + sizeof(from_data) + sizeof(from->device) +
sizeof(from->ndim) + sizeof(from->dtype) + sizeof(from->byte_offset) +
shape_bytes + sizeof(nbytes);
uint64_t overhead = RemoteCopyCalculatePacketOverheadSize(from, code, nbytes);
uint64_t packet_nbytes = overhead;

handler_->Write(packet_nbytes);
handler_->Write(code);
Expand Down Expand Up @@ -1009,11 +1006,55 @@ class RPCClientSession : public RPCSession, public DeviceAPI {
}

void CopyToRemote(void* local_from_bytes, DLTensor* remote_to, uint64_t nbytes) final {
endpoint_->CopyToRemote(local_from_bytes, remote_to, nbytes);
// Chunked copy: each endpoint_->CopyToRemote call below carries at most block_size
// payload bytes, where block_size is the value of GetRPCMaxTransferSize() minus the
// per-packet header overhead.
RPCCode code = RPCCode::kCopyToRemote;
uint64_t overhead = RemoteCopyCalculatePacketOverheadSize(remote_to, code, nbytes);
uint64_t rpc_max_size = GetRPCMaxTransferSize();
// The limit must leave room for at least one payload byte after the header.
ICHECK_GT(rpc_max_size, overhead) << "CopyToRemote: Invalid block size!";
const uint64_t block_size = rpc_max_size - overhead;
uint64_t block_count = 0;
const uint64_t num_blocks = nbytes / block_size;
void* from_bytes;

// Send every full-sized block, advancing the remote byte offset and the local source
// pointer by block_size each time.
// NOTE(review): remote_to->byte_offset is overwritten here and never restored, so any
// offset the caller had set on the tensor appears to be discarded -- confirm intended.
for (block_count = 0; block_count < num_blocks; block_count++) {
remote_to->byte_offset = block_count * block_size;
from_bytes = reinterpret_cast<void*>(
(reinterpret_cast<uint8_t*>(local_from_bytes) + block_count * block_size));
endpoint_->CopyToRemote(from_bytes, remote_to, block_size);
}

// Send the trailing partial block when nbytes is not a multiple of block_size.
// block_count equals num_blocks at this point.
const uint64_t remainder_bytes = nbytes % block_size;
if (remainder_bytes != 0) {
remote_to->byte_offset = block_count * block_size;
from_bytes = reinterpret_cast<void*>(
(reinterpret_cast<uint8_t*>(local_from_bytes) + block_count * block_size));
endpoint_->CopyToRemote(from_bytes, remote_to, remainder_bytes);
}
}

void CopyFromRemote(DLTensor* remote_from, void* local_to_bytes, uint64_t nbytes) final {
endpoint_->CopyFromRemote(remote_from, local_to_bytes, nbytes);
// Chunked copy: each endpoint_->CopyFromRemote call below requests at most block_size
// bytes, where block_size is the value of GetRPCMaxTransferSize() minus the per-packet
// header overhead.
RPCCode code = RPCCode::kCopyFromRemote;
uint64_t overhead = RemoteCopyCalculatePacketOverheadSize(remote_from, code, nbytes);
uint64_t rpc_max_size = GetRPCMaxTransferSize();
// The limit must leave room for at least one payload byte after the header.
ICHECK_GT(rpc_max_size, overhead) << "CopyFromRemote: Invalid block size!";
const uint64_t block_size = rpc_max_size - overhead;
uint64_t block_count = 0;
const uint64_t num_blocks = nbytes / block_size;
void* to_bytes;

// Fetch every full-sized block, advancing the remote byte offset and the local
// destination pointer by block_size each time.
// NOTE(review): remote_from->byte_offset is overwritten here and never restored, so any
// offset the caller had set on the tensor appears to be discarded -- confirm intended.
for (block_count = 0; block_count < num_blocks; block_count++) {
remote_from->byte_offset = block_count * block_size;
to_bytes = reinterpret_cast<void*>(
(reinterpret_cast<uint8_t*>(local_to_bytes) + block_count * block_size));
endpoint_->CopyFromRemote(remote_from, to_bytes, block_size);
}

// Fetch the trailing partial block when nbytes is not a multiple of block_size.
// block_count equals num_blocks at this point.
const uint64_t remainder_bytes = nbytes % block_size;
if (remainder_bytes != 0) {
remote_from->byte_offset = block_count * block_size;
to_bytes = reinterpret_cast<void*>(
(reinterpret_cast<uint8_t*>(local_to_bytes) + block_count * block_size));
endpoint_->CopyFromRemote(remote_from, to_bytes, remainder_bytes);
}
}

void FreeHandle(void* handle, int type_code) final {
Expand Down Expand Up @@ -1082,12 +1123,43 @@ class RPCClientSession : public RPCSession, public DeviceAPI {
bool IsLocalSession() const final { return false; }

private:
/*!
 * \brief Return the maximum RPC transfer size in bytes, resolving it on first use.
 *
 *  The value is cached in rpc_chunk_max_size_bytes_; a non-positive cached value means
 *  "not resolved yet". If the remote does not provide
 *  tvm.rpc.server.GetCRTMaxPacketSize, kRPCMaxTransferSizeBytesDefault is used, clamped
 *  to INT64_MAX so that it fits the signed cache without colliding with the sentinel.
 *
 * \return The maximum transfer size, always greater than zero.
 */
uint64_t GetRPCMaxTransferSize() {
  // Fast path: the limit was already resolved by an earlier call.
  if (rpc_chunk_max_size_bytes_ > 0) {
    return static_cast<uint64_t>(rpc_chunk_max_size_bytes_);
  }

  PackedFuncHandle rpc_func = GetFunction("tvm.rpc.server.GetCRTMaxPacketSize");
  if (rpc_func == nullptr) {
    // The default may exceed INT64_MAX (it is UINT64_MAX). Casting it straight to int64_t
    // would yield -1, which is indistinguishable from the "unresolved" sentinel and would
    // force a remote GetFunction lookup on every call. Clamp it to the largest positive
    // value the cache can hold instead.
    const uint64_t largest_cacheable = static_cast<uint64_t>(INT64_MAX);
    const uint64_t fallback = kRPCMaxTransferSizeBytesDefault < largest_cacheable
                                  ? kRPCMaxTransferSizeBytesDefault
                                  : largest_cacheable;
    rpc_chunk_max_size_bytes_ = static_cast<int64_t>(fallback);
  } else {
    CallFunc(rpc_func, nullptr, nullptr, 0, [this](TVMArgs args) {
      // Use args[1] as return value, args[0] is tcode
      // Look at RPCWrappedFunc in src/runtime/rpc/rpc_module.cc
      rpc_chunk_max_size_bytes_ = static_cast<int64_t>(args[1]);
      ICHECK_GT(rpc_chunk_max_size_bytes_, 0)
          << "RPC max transfer size is <= 0! (remote value = " << rpc_chunk_max_size_bytes_
          << ")";
    });
  }
  return static_cast<uint64_t>(rpc_chunk_max_size_bytes_);
}

std::shared_ptr<RPCEndpoint> endpoint_;
int64_t rpc_chunk_max_size_bytes_ = -1;
};

/*!
 * \brief Construct an RPCClientSession for the given endpoint.
 * \param endpoint The endpoint handed to the RPCClientSession constructor.
 * \return The new session, held as a shared RPCSession pointer.
 */
std::shared_ptr<RPCSession> CreateClientSession(std::shared_ptr<RPCEndpoint> endpoint) {
  std::shared_ptr<RPCSession> session = std::make_shared<RPCClientSession>(endpoint);
  return session;
}

// Sums the sizes of every non-payload field of a remote-copy packet: the RPC code, the
// tensor's data handle (encoded as a uint64_t), device, ndim, dtype, byte_offset, the
// shape array (ndim int64_t entries) and the nbytes field itself.
uint64_t RemoteCopyCalculatePacketOverheadSize(DLTensor* tensor, RPCCode code, uint64_t nbytes) {
  // Variable-sized part: one int64_t per tensor dimension.
  const uint64_t shape_field_bytes = tensor->ndim * sizeof(int64_t);

  // Fixed-sized part; only the sizes of the fields matter, not their values.
  uint64_t total = sizeof(code);
  total += sizeof(uint64_t);  // data handle, carried as a 64-bit integer
  total += sizeof(tensor->device);
  total += sizeof(tensor->ndim);
  total += sizeof(tensor->dtype);
  total += sizeof(tensor->byte_offset);
  total += shape_field_bytes;
  total += sizeof(nbytes);
  return total;
}

} // namespace runtime
} // namespace tvm
10 changes: 10 additions & 0 deletions src/runtime/rpc/rpc_endpoint.h
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,16 @@ template <typename... Args>
inline TVMRetValue RPCEndpoint::SysCallRemote(RPCCode code, Args&&... args) {
return syscall_remote_(static_cast<int>(code), std::forward<Args>(args)...);
}

/*!
* \brief Calculates the overhead (non-payload) size, in bytes, of a remote-copy packet.
*  Used for both CopyToRemote and CopyFromRemote transfers.
* \param tensor DLTensor being copied; its ndim determines the shape part of the overhead.
* \param code RPCCode for this transfer.
* \param nbytes Number of bytes to transfer.
* \return The remote-copy packet overhead size.
*/
uint64_t RemoteCopyCalculatePacketOverheadSize(DLTensor* tensor, RPCCode code, uint64_t nbytes);

} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_RPC_RPC_ENDPOINT_H_

0 comments on commit 90c862c

Please sign in to comment.