Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CPP_RPC] allow user supplied work dir #7670

Merged
merged 6 commits into from
Mar 19, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion apps/cpp_rpc/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ static const string kUsage =
"--tracker - The RPC tracker address in host:port format e.g. 10.1.1.2:9190 Default=\"\"\n"
"--key - The key used to identify the device type in tracker. Default=\"\"\n"
"--custom-addr - Custom IP Address to Report to RPC Tracker. Default=\"\"\n"
"--work-dir - Custom work directory. Default=\"\"\n"
"--silent - Whether to run in silent mode. Default=False\n"
"\n"
" Example\n"
Expand All @@ -70,6 +71,7 @@ static const string kUsage =
* \arg tracker The address of RPC tracker in host:port format e.g. 10.77.1.234:9190 Default=""
* \arg key The key used to identify the device type in tracker. Default=""
* \arg custom_addr Custom IP Address to Report to RPC Tracker. Default=""
* \arg work_dir Custom work directory. Default=""
* \arg silent Whether run in silent mode. Default=False
*/
struct RpcServerArgs {
Expand All @@ -79,6 +81,7 @@ struct RpcServerArgs {
string tracker;
string key;
string custom_addr;
string work_dir;
bool silent = false;
#if defined(WIN32)
std::string mmap_path;
Expand All @@ -96,6 +99,7 @@ void PrintArgs(const RpcServerArgs& args) {
LOG(INFO) << "tracker = " << args.tracker;
LOG(INFO) << "key = " << args.key;
LOG(INFO) << "custom_addr = " << args.custom_addr;
LOG(INFO) << "work_dir = " << args.work_dir;
LOG(INFO) << "silent = " << ((args.silent) ? ("True") : ("False"));
}

Expand Down Expand Up @@ -238,6 +242,10 @@ void ParseCmdArgs(int argc, char* argv[], struct RpcServerArgs& args) {
dmlc::InitLogging("--minloglevel=0");
}
#endif
const string work_dir = GetCmdOption(argc, argv, "--work-dir=");
if (!work_dir.empty()) {
args.work_dir = work_dir;
}
}

/*!
Expand Down Expand Up @@ -274,7 +282,7 @@ int RpcServer(int argc, char* argv[]) {
#endif

RPCServerCreate(args.host, args.port, args.port_end, args.tracker, args.key, args.custom_addr,
args.silent);
args.work_dir, args.silent);
return 0;
}

Expand Down
35 changes: 20 additions & 15 deletions apps/cpp_rpc/rpc_env.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ int mkdir(const char* path, int /* ignored */) { return _mkdir(path); }
#include <iostream>
#include <string>
#include <vector>

#include "../../src/support/utils.h"
#include "rpc_env.h"

Expand Down Expand Up @@ -85,25 +84,31 @@ void CleanDir(const std::string& dirname);
*/
std::string BuildSharedLibrary(std::string file_in);

RPCEnv::RPCEnv() {
RPCEnv::RPCEnv(const std::string& wd) {
if (wd != "") {
base_ = wd + "/.cache";
mkdir(wd.c_str(), 0777);
mkdir(base_.c_str(), 0777);
} else {
#if defined(ANDROID) || defined(__ANDROID__)
char cwd[PATH_MAX];
auto cmdline = fopen("/proc/self/cmdline", "r");
fread(cwd, 1, sizeof(cwd), cmdline);
fclose(cmdline);
base_ = "/data/data/" + std::string(cwd) + "/cache/rpc";
char cwd[PATH_MAX];
auto cmdline = fopen("/proc/self/cmdline", "r");
fread(cwd, 1, sizeof(cwd), cmdline);
fclose(cmdline);
base_ = "/data/data/" + std::string(cwd) + "/cache/rpc";
#elif !defined(_WIN32)
char cwd[PATH_MAX];
if (getcwd(cwd, sizeof(cwd))) {
base_ = std::string(cwd) + "/rpc";
} else {
base_ = "./rpc";
}
char cwd[PATH_MAX];
if (getcwd(cwd, sizeof(cwd))) {
base_ = std::string(cwd) + "/rpc";
} else {
base_ = "./rpc";
}
#else
base_ = "./rpc";
base_ = "./rpc";
#endif
mkdir(base_.c_str(), 0777);
}

mkdir(base_.c_str(), 0777);
TVM_REGISTER_GLOBAL("tvm.rpc.server.workpath").set_body([this](TVMArgs args, TVMRetValue* rv) {
*rv = this->GetPath(args[0]);
});
Expand Down
2 changes: 1 addition & 1 deletion apps/cpp_rpc/rpc_env.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ struct RPCEnv {
/*!
* \brief Constructor Init The RPC Environment initialize function
*/
RPCEnv();
RPCEnv(const std::string& word_dir = "");
/*!
* \brief GetPath To get the workpath from packed function
* \param name The file name
Expand Down
21 changes: 12 additions & 9 deletions apps/cpp_rpc/rpc_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -98,14 +98,15 @@ class RPCServer {
* \brief Constructor.
*/
RPCServer(std::string host, int port, int port_end, std::string tracker_addr, std::string key,
std::string custom_addr)
std::string custom_addr, std::string work_dir)
: host_(std::move(host)),
port_(port),
my_port_(0),
port_end_(port_end),
tracker_addr_(std::move(tracker_addr)),
key_(std::move(key)),
custom_addr_(std::move(custom_addr)) {}
custom_addr_(std::move(custom_addr)),
work_dir_(std::move(work_dir)) {}

/*!
* \brief Destructor.
Expand Down Expand Up @@ -174,7 +175,7 @@ class RPCServer {
const pid_t worker_pid = fork();
if (worker_pid == 0) {
// Worker process
ServerLoopProc(conn, addr);
ServerLoopProc(conn, addr, work_dir_);
_exit(0);
}

Expand All @@ -201,7 +202,7 @@ class RPCServer {
} else {
auto pid = fork();
if (pid == 0) {
ServerLoopProc(conn, addr);
ServerLoopProc(conn, addr, work_dir_);
exit(0);
}
// Wait for the result
Expand Down Expand Up @@ -308,9 +309,10 @@ class RPCServer {
* \param sock The socket information
* \param addr The socket address information
*/
static void ServerLoopProc(support::TCPSocket sock, support::SockAddr addr) {
static void ServerLoopProc(support::TCPSocket sock, support::SockAddr addr,
std::string work_dir) {
// Server loop
const auto env = RPCEnv();
const auto env = RPCEnv(work_dir);
RPCServerLoop(int(sock.sockfd));
LOG(INFO) << "Finish serving " << addr.AsString();
env.CleanUp();
Expand Down Expand Up @@ -339,6 +341,7 @@ class RPCServer {
std::string tracker_addr_;
std::string key_;
std::string custom_addr_;
std::string work_dir_;
support::TCPSocket listen_sock_;
support::TCPSocket tracker_sock_;
};
Expand Down Expand Up @@ -370,19 +373,19 @@ void ServerLoopFromChild(SOCKET socket) {
* silent mode. Default=True
*/
void RPCServerCreate(std::string host, int port, int port_end, std::string tracker_addr,
std::string key, std::string custom_addr, bool silent) {
std::string key, std::string custom_addr, std::string work_dir, bool silent) {
if (silent) {
// Only errors and fatal is logged
dmlc::InitLogging("--minloglevel=2");
}
// Start the rpc server
RPCServer rpc(std::move(host), port, port_end, std::move(tracker_addr), std::move(key),
std::move(custom_addr));
std::move(custom_addr), std::move(work_dir));
rpc.Start();
}

TVM_REGISTER_GLOBAL("rpc.ServerCreate").set_body([](TVMArgs args, TVMRetValue* rv) {
RPCServerCreate(args[0], args[1], args[2], args[3], args[4], args[5], args[6]);
RPCServerCreate(args[0], args[1], args[2], args[3], args[4], args[5], args[6], args[7]);
});
} // namespace runtime
} // namespace tvm
3 changes: 2 additions & 1 deletion apps/cpp_rpc/rpc_server.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,12 @@ void ServerLoopFromChild(SOCKET socket);
* \param tracker The address of RPC tracker in host:port format e.g. 10.77.1.234:9190 Default=""
* \param key The key used to identify the device type in tracker. Default=""
* \param custom_addr Custom IP Address to Report to RPC Tracker. Default=""
* \param work_dir Custom work directory. Default=""
* \param silent Whether run in silent mode. Default=True
*/
void RPCServerCreate(std::string host = "", int port = 9090, int port_end = 9099,
std::string tracker_addr = "", std::string key = "",
std::string custom_addr = "", bool silent = true);
std::string custom_addr = "", std::string work_dir = "", bool silent = true);
} // namespace runtime
} // namespace tvm
#endif // TVM_APPS_CPP_RPC_SERVER_H_