Skip to content

Commit

Permalink
Allow RPCWrappedFunc to rewrite runtime::String as std::string (apach…
Browse files Browse the repository at this point in the history
  • Loading branch information
junrushao authored and trevor-m committed Jun 18, 2020
1 parent 841d0bf commit ee165d6
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 1 deletion.
7 changes: 6 additions & 1 deletion src/runtime/rpc/rpc_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,13 @@ class RPCWrappedFunc : public Object {
// scan and check whether we need rewrite these arguments
// to their remote variant.
for (int i = 0; i < args.size(); ++i) {
if (args[i].IsObjectRef<String>()) {
String str = args[i];
type_codes[i] = kTVMStr;
values[i].v_str = str.c_str();
continue;
}
int tcode = type_codes[i];

switch (tcode) {
case kTVMDLTensorHandle:
case kTVMNDArrayHandle: {
Expand Down
16 changes: 16 additions & 0 deletions tests/python/unittest/test_runtime_rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,22 @@ def remotethrow(name):
f2 = client.get_function("rpc.test.strcat")
assert f2("abc", 11) == "abc:11"


def test_rpc_runtime_string():
if not tvm.runtime.enabled("rpc"):
return
@tvm.register_func("rpc.test.runtime_str_concat")
def strcat(x, y):
return x + y

server = rpc.Server("localhost", key="x1")
client = rpc.connect(server.host, server.port, key="x1")
func = client.get_function("rpc.test.runtime_str_concat")
x = tvm.runtime.container.String("abc")
y = tvm.runtime.container.String("def")
assert str(func(x, y)) == "abcdef"


def test_rpc_array():
if not tvm.runtime.enabled("rpc"):
return
Expand Down

0 comments on commit ee165d6

Please sign in to comment.