diff --git a/apps/bundle_deploy/crt_config/crt_config.h b/apps/bundle_deploy/crt_config/crt_config.h index 97b6c2103f4b..11086c0e9a15 100644 --- a/apps/bundle_deploy/crt_config/crt_config.h +++ b/apps/bundle_deploy/crt_config/crt_config.h @@ -45,4 +45,7 @@ /*! Size of the global function registry, in bytes. */ #define TVM_CRT_GLOBAL_FUNC_REGISTRY_SIZE_BYTES 200 +/*! Maximum packet size, in bytes, including the length header. */ +#define TVM_CRT_MAX_PACKET_SIZE_BYTES 512 + #endif // TVM_RUNTIME_CRT_CONFIG_H_ diff --git a/src/runtime/crt/common/crt_runtime_api.c b/src/runtime/crt/common/crt_runtime_api.c index 172972cc57b9..f73449829bd6 100644 --- a/src/runtime/crt/common/crt_runtime_api.c +++ b/src/runtime/crt/common/crt_runtime_api.c @@ -307,8 +307,14 @@ static tvm_crt_error_t FindFunctionOrSetAPIError(tvm_module_index_t module_index } int TVMFuncGetGlobal(const char* name, TVMFunctionHandle* out) { - return FindFunctionOrSetAPIError(kGlobalFuncModuleIndex, &global_func_registry.registry, name, - out); + tvm_crt_error_t to_return = + FindFunctionOrSetAPIError(kGlobalFuncModuleIndex, &global_func_registry.registry, name, out); + // For compatibility with the C++ runtime equivalent, in src/runtime/registry.cc. + if (to_return == kTvmErrorFunctionNameNotFound) { + *out = NULL; + to_return = kTvmErrorNoError; + } + return to_return; } int TVMModGetFunction(TVMModuleHandle mod, const char* func_name, int query_imports, @@ -352,7 +358,6 @@ int ModuleGetFunction(TVMValue* args, int* type_codes, int num_args, TVMValue* r if (to_return == kTvmErrorFunctionNameNotFound) { to_return = kTvmErrorNoError; } - return to_return; } @@ -381,6 +386,17 @@ int TVMFuncFree(TVMFunctionHandle func) { int RPCTimeEvaluator(TVMValue* args, int* type_codes, int num_args, TVMValue* ret_val, int* ret_type_code); + +// Sends CRT max packet size. +int RPCGetCRTMaxPacketSize(TVMValue* args, int* type_codes, int num_args, TVMValue* ret_value, + int* ret_type_codes) { + // 11 bytes is for microtvm overhead: + // packet start(2), length(4), session header(3), crc(2) + ret_value[0].v_int64 = TVM_CRT_MAX_PACKET_SIZE_BYTES - 11; + ret_type_codes[0] = kTVMArgInt; + return 0; +} + tvm_crt_error_t TVMInitializeRuntime() { int idx = 0; tvm_crt_error_t error = kTvmErrorNoError; @@ -421,6 +437,10 @@ tvm_crt_error_t TVMInitializeRuntime() { error = TVMFuncRegisterGlobal("runtime.RPCTimeEvaluator", &RPCTimeEvaluator, 0); } + if (error == kTvmErrorNoError) { + error = TVMFuncRegisterGlobal("tvm.rpc.server.GetCRTMaxPacketSize", &RPCGetCRTMaxPacketSize, 0); + } + if (error != kTvmErrorNoError) { TVMPlatformMemoryFree(registry_backing_memory, dev); TVMPlatformMemoryFree(func_registry_memory, dev); diff --git a/src/runtime/crt/host/crt_config.h b/src/runtime/crt/host/crt_config.h index 109abaf04083..b81a74eb4ae6 100644 --- a/src/runtime/crt/host/crt_config.h +++ b/src/runtime/crt/host/crt_config.h @@ -43,10 +43,10 @@ #define TVM_CRT_MAX_REGISTERED_MODULES 2 /*! Size of the global function registry, in bytes. */ -#define TVM_CRT_GLOBAL_FUNC_REGISTRY_SIZE_BYTES 256 +#define TVM_CRT_GLOBAL_FUNC_REGISTRY_SIZE_BYTES 512 /*! Maximum packet size, in bytes, including the length header. */ -#define TVM_CRT_MAX_PACKET_SIZE_BYTES 64000 +#define TVM_CRT_MAX_PACKET_SIZE_BYTES 8 * 1024 /*! \brief Maximum length of a PackedFunc function name. */ #define TVM_CRT_MAX_FUNCTION_NAME_LENGTH_BYTES 30 diff --git a/src/runtime/crt/host/main.cc b/src/runtime/crt/host/main.cc index e64455417928..53a0572560e2 100644 --- a/src/runtime/crt/host/main.cc +++ b/src/runtime/crt/host/main.cc @@ -136,9 +136,12 @@ int main(int argc, char** argv) { "failed to register GraphExecutor TVMModule"); #endif - if (TVMFuncRegisterGlobal("tvm.testing.reset_server", (TVMFunctionHandle)&testonly_reset_server, - 0)) { - fprintf(stderr, "utvm runtime: internal error registering global packedfunc; exiting\n"); + int error = TVMFuncRegisterGlobal("tvm.testing.reset_server", + (TVMFunctionHandle)&testonly_reset_server, 0); + if (error) { + fprintf(stderr, + "utvm runtime: internal error (error#: %x) registering global packedfunc; exiting\n", + error); return 2; } diff --git a/src/runtime/minrpc/rpc_reference.h b/src/runtime/minrpc/rpc_reference.h index ace3e2bbb1b8..20d89ad52a1d 100644 --- a/src/runtime/minrpc/rpc_reference.h +++ b/src/runtime/minrpc/rpc_reference.h @@ -30,6 +30,9 @@ namespace runtime { /*! \brief The current RPC procotol version. */ constexpr const char* kRPCProtocolVer = "0.8.0"; +// When tvm.rpc.server.GetCRTMaxPacketSize global function is not registered. +const uint64_t kRPCMaxTransferSizeBytesDefault = UINT64_MAX; + /*! \brief The RPC code */ enum class RPCCode : int { kNone, diff --git a/src/runtime/rpc/rpc_endpoint.cc b/src/runtime/rpc/rpc_endpoint.cc index 28f93f641b4b..9bb782b384dd 100644 --- a/src/runtime/rpc/rpc_endpoint.cc +++ b/src/runtime/rpc/rpc_endpoint.cc @@ -330,7 +330,7 @@ class RPCEndpoint::EventHandler : public dmlc::Stream { } /*! - * \brief Recive incoming packed seq from the stream. + * \brief Receive incoming packed seq from the stream. * \return The received argments. * \note The TVMArgs is available until we switchstate. */ @@ -369,7 +369,6 @@ class RPCEndpoint::EventHandler : public dmlc::Stream { */ void HandleReturn(RPCCode code, RPCSession::FEncodeReturn setreturn) { TVMArgs args = RecvPackedSeq(); - if (code == RPCCode::kException) { // switch to the state before sending exception. this->SwitchToState(kRecvPacketNumBytes); @@ -802,14 +801,13 @@ void RPCEndpoint::CopyToRemote(void* from_bytes, DLTensor* to, uint64_t nbytes) std::lock_guard lock(mutex_); RPCCode code = RPCCode::kCopyToRemote; - uint64_t num_data_bytes = static_cast(GetDataSize(*to)); - ICHECK_EQ(nbytes, num_data_bytes); + uint64_t tensor_total_size_bytes = static_cast(GetDataSize(*to)); + ICHECK_LE(to->byte_offset + nbytes, tensor_total_size_bytes) + << "CopyToRemote: overflow in tensor size: (byte_offset=" << to->byte_offset + << ", nbytes=" << nbytes << ", tensor_total_size=" << tensor_total_size_bytes << ")"; - uint64_t to_data = reinterpret_cast(to->data); - uint64_t shape_bytes = to->ndim * sizeof(int64_t); - uint64_t packet_nbytes = sizeof(code) + sizeof(to_data) + sizeof(to->device) + sizeof(to->ndim) + - sizeof(to->dtype) + sizeof(to->byte_offset) + shape_bytes + - sizeof(nbytes) + num_data_bytes; + uint64_t overhead = RemoteCopyCalculatePacketOverheadSize(to, code, nbytes); + uint64_t packet_nbytes = overhead + nbytes; handler_->Write(packet_nbytes); handler_->Write(code); @@ -823,14 +821,13 @@ void RPCEndpoint::CopyFromRemote(DLTensor* from, void* to_bytes, uint64_t nbytes std::lock_guard lock(mutex_); RPCCode code = RPCCode::kCopyFromRemote; - uint64_t num_data_bytes = static_cast(GetDataSize(*from)); - CHECK_EQ(nbytes, num_data_bytes); + uint64_t tensor_total_size_bytes = static_cast(GetDataSize(*from)); + ICHECK_LE(from->byte_offset + nbytes, tensor_total_size_bytes) + << "CopyFromRemote: overflow in tensor size: (byte_offset=" << from->byte_offset + << ", nbytes=" << nbytes << ", tensor_total_size=" << tensor_total_size_bytes << ")"; - uint64_t from_data = reinterpret_cast(from->data); - uint64_t shape_bytes = from->ndim * sizeof(int64_t); - uint64_t packet_nbytes = sizeof(code) + sizeof(from_data) + sizeof(from->device) + - sizeof(from->ndim) + sizeof(from->dtype) + sizeof(from->byte_offset) + - shape_bytes + sizeof(nbytes); + uint64_t overhead = RemoteCopyCalculatePacketOverheadSize(from, code, nbytes); + uint64_t packet_nbytes = overhead; handler_->Write(packet_nbytes); handler_->Write(code); @@ -1009,11 +1006,55 @@ class RPCClientSession : public RPCSession, public DeviceAPI { } void CopyToRemote(void* local_from_bytes, DLTensor* remote_to, uint64_t nbytes) final { - endpoint_->CopyToRemote(local_from_bytes, remote_to, nbytes); + RPCCode code = RPCCode::kCopyToRemote; + uint64_t overhead = RemoteCopyCalculatePacketOverheadSize(remote_to, code, nbytes); + uint64_t rpc_max_size = GetRPCMaxTransferSize(); + ICHECK_GT(rpc_max_size, overhead) << "CopyToRemote: Invalid block size!"; + const uint64_t block_size = rpc_max_size - overhead; + uint64_t block_count = 0; + const uint64_t num_blocks = nbytes / block_size; + void* from_bytes; + + for (block_count = 0; block_count < num_blocks; block_count++) { + remote_to->byte_offset = block_count * block_size; + from_bytes = reinterpret_cast( + (reinterpret_cast(local_from_bytes) + block_count * block_size)); + endpoint_->CopyToRemote(from_bytes, remote_to, block_size); + } + + const uint64_t remainder_bytes = nbytes % block_size; + if (remainder_bytes != 0) { + remote_to->byte_offset = block_count * block_size; + from_bytes = reinterpret_cast( + (reinterpret_cast(local_from_bytes) + block_count * block_size)); + endpoint_->CopyToRemote(from_bytes, remote_to, remainder_bytes); + } } void CopyFromRemote(DLTensor* remote_from, void* local_to_bytes, uint64_t nbytes) final { - endpoint_->CopyFromRemote(remote_from, local_to_bytes, nbytes); + RPCCode code = RPCCode::kCopyFromRemote; + uint64_t overhead = RemoteCopyCalculatePacketOverheadSize(remote_from, code, nbytes); + uint64_t rpc_max_size = GetRPCMaxTransferSize(); + ICHECK_GT(rpc_max_size, overhead) << "CopyFromRemote: Invalid block size!"; + const uint64_t block_size = rpc_max_size - overhead; + uint64_t block_count = 0; + const uint64_t num_blocks = nbytes / block_size; + void* to_bytes; + + for (block_count = 0; block_count < num_blocks; block_count++) { + remote_from->byte_offset = block_count * block_size; + to_bytes = reinterpret_cast( + (reinterpret_cast(local_to_bytes) + block_count * block_size)); + endpoint_->CopyFromRemote(remote_from, to_bytes, block_size); + } + + const uint64_t remainder_bytes = nbytes % block_size; + if (remainder_bytes != 0) { + remote_from->byte_offset = block_count * block_size; + to_bytes = reinterpret_cast( + (reinterpret_cast(local_to_bytes) + block_count * block_size)); + endpoint_->CopyFromRemote(remote_from, to_bytes, remainder_bytes); + } } void FreeHandle(void* handle, int type_code) final { @@ -1082,12 +1123,43 @@ class RPCClientSession : public RPCSession, public DeviceAPI { bool IsLocalSession() const final { return false; } private: + uint64_t GetRPCMaxTransferSize() { + if (rpc_chunk_max_size_bytes_ > 0) { + return (uint64_t)rpc_chunk_max_size_bytes_; + } + + PackedFuncHandle rpc_func = GetFunction("tvm.rpc.server.GetCRTMaxPacketSize"); + if (rpc_func == nullptr) { + rpc_chunk_max_size_bytes_ = (int64_t)kRPCMaxTransferSizeBytesDefault; + } else { + CallFunc(rpc_func, nullptr, nullptr, 0, [this](TVMArgs args) { + // Use args[1] as return value, args[0] is tcode + // Look at RPCWrappedFunc in src/runtime/rpc/rpc_module.cc + rpc_chunk_max_size_bytes_ = (int64_t)args[1]; + ICHECK_GT(rpc_chunk_max_size_bytes_, 0) + << "RPC max transfer size is <= 0! (remote value = " << rpc_chunk_max_size_bytes_ + << ")"; + }); + } + return (uint64_t)rpc_chunk_max_size_bytes_; + } + std::shared_ptr endpoint_; + int64_t rpc_chunk_max_size_bytes_ = -1; }; std::shared_ptr CreateClientSession(std::shared_ptr endpoint) { return std::make_shared(endpoint); } +uint64_t RemoteCopyCalculatePacketOverheadSize(DLTensor* tensor, RPCCode code, uint64_t nbytes) { + uint64_t shape_bytes = tensor->ndim * sizeof(int64_t); + uint64_t to_data = reinterpret_cast(static_cast(tensor->data)); + uint64_t overhead = sizeof(code) + sizeof(to_data) + sizeof(tensor->device) + + sizeof(tensor->ndim) + sizeof(tensor->dtype) + sizeof(tensor->byte_offset) + + shape_bytes + sizeof(nbytes); + return overhead; +} + } // namespace runtime } // namespace tvm diff --git a/src/runtime/rpc/rpc_endpoint.h b/src/runtime/rpc/rpc_endpoint.h index cd3c9b2bec72..7c11a1aeac01 100644 --- a/src/runtime/rpc/rpc_endpoint.h +++ b/src/runtime/rpc/rpc_endpoint.h @@ -204,6 +204,16 @@ template inline TVMRetValue RPCEndpoint::SysCallRemote(RPCCode code, Args&&... args) { return syscall_remote_(static_cast(code), std::forward(args)...); } + +/*! + * \brief Calculates overhead size of a CopyToRemote packet. + * \param to DLTensor to copy. + * \param code RPCCode for this transfer. + * \param nbytes Number of bytes to transfer. + * \return The remote-copy packet overhead size. + */ +uint64_t RemoteCopyCalculatePacketOverheadSize(DLTensor* tensor, RPCCode code, uint64_t nbytes); + } // namespace runtime } // namespace tvm #endif // TVM_RUNTIME_RPC_RPC_ENDPOINT_H_