diff --git a/src/runtime/rpc/rpc_module.cc b/src/runtime/rpc/rpc_module.cc index b9cdc2cf82ad..89f3e7c6c7f8 100644 --- a/src/runtime/rpc/rpc_module.cc +++ b/src/runtime/rpc/rpc_module.cc @@ -48,8 +48,13 @@ class RPCWrappedFunc : public Object { // scan and check whether we need rewrite these arguments // to their remote variant. for (int i = 0; i < args.size(); ++i) { + if (args[i].IsObjectRef()) { + String str = args[i]; + type_codes[i] = kTVMStr; + values[i].v_str = str.c_str(); + continue; + } int tcode = type_codes[i]; - switch (tcode) { case kTVMDLTensorHandle: case kTVMNDArrayHandle: { diff --git a/tests/python/unittest/test_runtime_rpc.py b/tests/python/unittest/test_runtime_rpc.py index dfbb3c55227b..7f01f880cd3d 100644 --- a/tests/python/unittest/test_runtime_rpc.py +++ b/tests/python/unittest/test_runtime_rpc.py @@ -86,6 +86,22 @@ def remotethrow(name): f2 = client.get_function("rpc.test.strcat") assert f2("abc", 11) == "abc:11" + +def test_rpc_runtime_string(): + if not tvm.runtime.enabled("rpc"): + return + @tvm.register_func("rpc.test.runtime_str_concat") + def strcat(x, y): + return x + y + + server = rpc.Server("localhost", key="x1") + client = rpc.connect(server.host, server.port, key="x1") + func = client.get_function("rpc.test.runtime_str_concat") + x = tvm.runtime.container.String("abc") + y = tvm.runtime.container.String("def") + assert str(func(x, y)) == "abcdef" + + def test_rpc_array(): if not tvm.runtime.enabled("rpc"): return