Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RUNTIME][RPC] Update RPC runtime to allow remote module as arg #4462

Merged
merged 2 commits into from
Dec 3, 2019
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 6 additions & 14 deletions python/tvm/contrib/debugger/debug_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,25 +60,17 @@ def create(graph_json_str, libmod, ctx, dump_root=None):
except AttributeError:
raise ValueError("Type %s is not supported" % type(graph_json_str))
try:
fcreate = get_global_func("tvm.graph_runtime_debug.create")
ctx, num_rpc_ctx, device_type_id = graph_runtime.get_device_ctx(libmod, ctx)
if num_rpc_ctx == len(ctx):
fcreate = ctx[0]._rpc_sess.get_function(
"tvm.graph_runtime_debug.create")
else:
fcreate = get_global_func("tvm.graph_runtime_debug.create")
except ValueError:
raise ValueError(
"Please set '(USE_GRAPH_RUNTIME_DEBUG ON)' in "
"config.cmake and rebuild TVM to enable debug mode"
)

ctx, num_rpc_ctx, device_type_id = graph_runtime.get_device_ctx(libmod, ctx)
if num_rpc_ctx == len(ctx):
libmod = rpc_base._ModuleHandle(libmod)
try:
fcreate = ctx[0]._rpc_sess.get_function(
"tvm.graph_runtime_debug.remote_create"
)
except ValueError:
raise ValueError(
"Please set '(USE_GRAPH_RUNTIME_DEBUG ON)' in "
"config.cmake and rebuild TVM to enable debug mode"
)
func_obj = fcreate(graph_json_str, libmod, *device_type_id)
return GraphModuleDebug(func_obj, ctx, graph_json_str, dump_root)

Expand Down
7 changes: 3 additions & 4 deletions python/tvm/contrib/graph_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,10 @@ def create(graph_json_str, libmod, ctx):
ctx, num_rpc_ctx, device_type_id = get_device_ctx(libmod, ctx)

if num_rpc_ctx == len(ctx):
hmod = rpc_base._ModuleHandle(libmod)
fcreate = ctx[0]._rpc_sess.get_function("tvm.graph_runtime.remote_create")
return GraphModule(fcreate(graph_json_str, hmod, *device_type_id))
fcreate = ctx[0]._rpc_sess.get_function("tvm.graph_runtime.create")
else:
fcreate = get_global_func("tvm.graph_runtime.create")

fcreate = get_global_func("tvm.graph_runtime.create")
return GraphModule(fcreate(graph_json_str, libmod, *device_type_id))

def get_device_ctx(libmod, ctx):
Expand Down
15 changes: 0 additions & 15 deletions src/runtime/graph/debug/graph_runtime_debug.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
#include <chrono>
#include <sstream>
#include "../graph_runtime.h"
#include "../../object_internal.h"

namespace tvm {
namespace runtime {
Expand Down Expand Up @@ -220,19 +219,5 @@ TVM_REGISTER_GLOBAL("tvm.graph_runtime_debug.create")
<< args.num_args;
*rv = GraphRuntimeDebugCreate(args[0], args[1], GetAllContext(args));
});

TVM_REGISTER_GLOBAL("tvm.graph_runtime_debug.remote_create")
.set_body([](TVMArgs args, TVMRetValue* rv) {
CHECK_GE(args.num_args, 4) << "The expected number of arguments for "
"graph_runtime.remote_create is "
"at least 4, but it has "
<< args.num_args;
void* mhandle = args[1];
ModuleNode* mnode = ObjectInternal::GetModuleNode(mhandle);
const auto& contexts = GetAllContext(args);
*rv = GraphRuntimeDebugCreate(
args[0], GetRef<Module>(mnode), contexts);
});

} // namespace runtime
} // namespace tvm
15 changes: 0 additions & 15 deletions src/runtime/graph/graph_runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
#include <vector>

#include "graph_runtime.h"
#include "../object_internal.h"

namespace tvm {
namespace runtime {
Expand Down Expand Up @@ -511,19 +510,5 @@ TVM_REGISTER_GLOBAL("tvm.graph_runtime.create")
const auto& contexts = GetAllContext(args);
*rv = GraphRuntimeCreate(args[0], args[1], contexts);
});

TVM_REGISTER_GLOBAL("tvm.graph_runtime.remote_create")
.set_body([](TVMArgs args, TVMRetValue* rv) {
CHECK_GE(args.num_args, 4) << "The expected number of arguments for "
"graph_runtime.remote_create is "
"at least 4, but it has "
<< args.num_args;
void* mhandle = args[1];
ModuleNode* mnode = ObjectInternal::GetModuleNode(mhandle);

const auto& contexts = GetAllContext(args);
*rv = GraphRuntimeCreate(
args[0], GetRef<Module>(mnode), contexts);
});
} // namespace runtime
} // namespace tvm
24 changes: 23 additions & 1 deletion src/runtime/rpc/rpc_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class RPCWrappedFunc {
}

void operator()(TVMArgs args, TVMRetValue *rv) const {
sess_->CallFunc(handle_, args, rv, &fwrap_);
sess_->CallFunc(handle_, args, rv, UnwrapRemote, &fwrap_);
}
~RPCWrappedFunc() {
try {
Expand All @@ -55,6 +55,9 @@ class RPCWrappedFunc {
TVMArgs args,
TVMRetValue* rv);

static void* UnwrapRemote(int rpc_sess_table_index,
const TVMArgValue& arg);

// deleter of RPC remote array
static void RemoteNDArrayDeleter(NDArray::Container* ptr) {
RemoteSpace* space = static_cast<RemoteSpace*>(ptr->dl_tensor.data);
Expand Down Expand Up @@ -181,6 +184,25 @@ class RPCModuleNode final : public ModuleNode {
PackedFunc fwrap_;
};

void* RPCWrappedFunc::UnwrapRemote(int rpc_sess_table_index,
const TVMArgValue& arg) {
if (arg.type_code() == kModuleHandle) {
Module mod = arg;
std::string tkey = mod->type_key();
CHECK_EQ(tkey, "rpc")
<< "ValueError: Cannot pass a non-RPC module to remote";
auto* rmod = static_cast<RPCModuleNode*>(mod.operator->());
CHECK_EQ(rmod->sess()->table_index(), rpc_sess_table_index)
<< "ValueError: Cannot pass in module into a different remote session";
return rmod->module_handle();
} else {
LOG(FATAL) << "ValueError: Cannot pass type "
<< runtime::TypeCode2Str(arg.type_code())
<< " as an argument to the remote";
return nullptr;
}
}

void RPCWrappedFunc::WrapRemote(std::shared_ptr<RPCSession> sess,
TVMArgs args,
TVMRetValue *rv) {
Expand Down
64 changes: 47 additions & 17 deletions src/runtime/rpc/rpc_session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -202,23 +202,33 @@ class RPCSession::EventHandler : public dmlc::Stream {
return ctx;
}
// Send Packed sequence to writer.
//
// client_mode: whether we are in client mode.
//
// funwrap: auxiliary function to unwrap remote Object
// when it is provided, we need to unwrap objects.
//
// return_ndarray is a special flag to handle returning of ndarray
// In this case, we return the shape, context and data of the array,
// as well as a customized PackedFunc that handles deletion of
// the array in the remote.
void SendPackedSeq(const TVMValue* arg_values,
const int* type_codes,
int n,
int num_args,
bool client_mode,
FUnwrapRemoteObject funwrap = nullptr,
bool return_ndarray = false) {
this->Write(n);
for (int i = 0; i < n; ++i) {
std::swap(client_mode_, client_mode);

this->Write(num_args);
for (int i = 0; i < num_args; ++i) {
int tcode = type_codes[i];
if (tcode == kNDArrayContainer) tcode = kArrayHandle;
this->Write(tcode);
}

// Argument packing.
for (int i = 0; i < n; ++i) {
for (int i = 0; i < num_args; ++i) {
int tcode = type_codes[i];
TVMValue value = arg_values[i];
switch (tcode) {
Expand All @@ -241,7 +251,23 @@ class RPCSession::EventHandler : public dmlc::Stream {
break;
}
case kFuncHandle:
case kModuleHandle:
case kModuleHandle: {
// always send handle in 64 bit.
uint64_t handle;
// allow pass module as argument to remote.
if (funwrap != nullptr) {
void* remote_handle = (*funwrap)(
rpc_sess_table_index_,
runtime::TVMArgValue(value, tcode));
handle = reinterpret_cast<uint64_t>(remote_handle);
} else {
CHECK(!client_mode_)
<< "Cannot directly pass remote object as argument";
handle = reinterpret_cast<uint64_t>(value.v_handle);
}
this->Write(handle);
break;
}
case kHandle: {
// always send handle in 64 bit.
uint64_t handle = reinterpret_cast<uint64_t>(value.v_handle);
Expand Down Expand Up @@ -300,6 +326,7 @@ class RPCSession::EventHandler : public dmlc::Stream {
}
}
}
std::swap(client_mode_, client_mode);
}

// Endian aware IO handling
Expand Down Expand Up @@ -430,11 +457,11 @@ class RPCSession::EventHandler : public dmlc::Stream {
case kHandle:
case kStr:
case kBytes:
case kModuleHandle:
case kTVMContext: {
this->RequestBytes(sizeof(TVMValue)); break;
}
case kFuncHandle:
case kModuleHandle: {
case kFuncHandle: {
CHECK(client_mode_)
<< "Only client can receive remote functions";
this->RequestBytes(sizeof(TVMValue)); break;
Expand Down Expand Up @@ -656,7 +683,7 @@ class RPCSession::EventHandler : public dmlc::Stream {
TVMValue ret_value;
ret_value.v_str = e.what();
int ret_tcode = kStr;
SendPackedSeq(&ret_value, &ret_tcode, 1);
SendPackedSeq(&ret_value, &ret_tcode, 1, false);
}
}
this->SwitchToState(kRecvCode);
Expand Down Expand Up @@ -711,7 +738,7 @@ class RPCSession::EventHandler : public dmlc::Stream {
}
}
this->Write(code);
SendPackedSeq(&ret_value, &ret_tcode, 1);
SendPackedSeq(&ret_value, &ret_tcode, 1, false);
arg_recv_stage_ = 0;
this->SwitchToState(kRecvCode);
}
Expand All @@ -734,22 +761,22 @@ class RPCSession::EventHandler : public dmlc::Stream {
if (rv.type_code() == kStr) {
ret_value.v_str = rv.ptr<std::string>()->c_str();
ret_tcode = kStr;
SendPackedSeq(&ret_value, &ret_tcode, 1);
SendPackedSeq(&ret_value, &ret_tcode, 1, false);
} else if (rv.type_code() == kBytes) {
std::string* bytes = rv.ptr<std::string>();
TVMByteArray arr;
arr.data = bytes->c_str();
arr.size = bytes->length();
ret_value.v_handle = &arr;
ret_tcode = kBytes;
SendPackedSeq(&ret_value, &ret_tcode, 1);
SendPackedSeq(&ret_value, &ret_tcode, 1, false);
} else if (rv.type_code() == kFuncHandle ||
rv.type_code() == kModuleHandle) {
// always send handle in 64 bit.
CHECK(!client_mode_)
<< "Only server can send function and module handle back.";
rv.MoveToCHost(&ret_value, &ret_tcode);
SendPackedSeq(&ret_value, &ret_tcode, 1);
SendPackedSeq(&ret_value, &ret_tcode, 1, false);
} else if (rv.type_code() == kNDArrayContainer) {
// always send handle in 64 bit.
CHECK(!client_mode_)
Expand All @@ -764,18 +791,18 @@ class RPCSession::EventHandler : public dmlc::Stream {
NDArray::Container* nd = static_cast<NDArray::Container*>(ret_value_pack[0].v_handle);
ret_value_pack[1].v_handle = nd;
ret_tcode_pack[1] = kHandle;
SendPackedSeq(ret_value_pack, ret_tcode_pack, 2, true);
SendPackedSeq(ret_value_pack, ret_tcode_pack, 2, false, nullptr, true);
} else {
ret_value = rv.value();
ret_tcode = rv.type_code();
SendPackedSeq(&ret_value, &ret_tcode, 1);
SendPackedSeq(&ret_value, &ret_tcode, 1, false);
}
} catch (const std::runtime_error& e) {
RPCCode code = RPCCode::kException;
this->Write(code);
ret_value.v_str = e.what();
ret_tcode = kStr;
SendPackedSeq(&ret_value, &ret_tcode, 1);
SendPackedSeq(&ret_value, &ret_tcode, 1, false);
}
}

Expand Down Expand Up @@ -873,7 +900,7 @@ void RPCSession::Init() {
&reader_, &writer_, table_index_, name_, &remote_key_);
// Quick function to call remote.
call_remote_ = PackedFunc([this](TVMArgs args, TVMRetValue* rv) {
handler_->SendPackedSeq(args.values, args.type_codes, args.num_args);
handler_->SendPackedSeq(args.values, args.type_codes, args.num_args, true);
RPCCode code = HandleUntilReturnEvent(rv, true, nullptr);
CHECK(code == RPCCode::kReturn) << "code=" << static_cast<int>(code);
});
Expand Down Expand Up @@ -954,13 +981,16 @@ int RPCSession::ServerEventHandler(const std::string& bytes, int event_flag) {
void RPCSession::CallFunc(void* h,
TVMArgs args,
TVMRetValue* rv,
FUnwrapRemoteObject funwrap,
const PackedFunc* fwrap) {
std::lock_guard<std::recursive_mutex> lock(mutex_);

RPCCode code = RPCCode::kCallFunc;
handler_->Write(code);
uint64_t handle = reinterpret_cast<uint64_t>(h);
handler_->Write(handle);
handler_->SendPackedSeq(args.values, args.type_codes, args.num_args);
handler_->SendPackedSeq(
args.values, args.type_codes, args.num_args, true, funwrap);
code = HandleUntilReturnEvent(rv, true, fwrap);
CHECK(code == RPCCode::kReturn) << "code=" << static_cast<int>(code);
}
Expand Down
12 changes: 12 additions & 0 deletions src/runtime/rpc/rpc_session.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,16 @@ enum class RPCCode : int {
kNDArrayFree
};

/*!
* \brief Function that unwraps a remote object to its handle.
* \param rpc_sess_table_index RPC session table index for validation.
* \param obj Handle to the object argument.
* \return The corresponding handle.
*/
typedef void* (*FUnwrapRemoteObject)(
int rpc_sess_table_index,
const TVMArgValue& obj);

/*!
* \brief Abstract channel interface used to create RPCSession.
*/
Expand Down Expand Up @@ -144,11 +154,13 @@ class RPCSession {
* \param handle The function handle
* \param args The arguments
* \param rv The return value.
* \param funpwrap Function that takes a remote object and returns the raw handle.
* \param fwrap Wrapper function to turn Function/Module handle into real return.
*/
void CallFunc(RPCFuncHandle handle,
TVMArgs args,
TVMRetValue* rv,
FUnwrapRemoteObject funwrap,
const PackedFunc* fwrap);
/*!
* \brief Copy bytes into remote array content.
Expand Down