Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Windows Support for cpp_rpc #4857

Merged
merged 11 commits into from
Apr 15, 2020
8 changes: 8 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,14 @@ tvm_option(USE_NNPACK "Build with nnpack support" OFF)
tvm_option(USE_RANDOM "Build with random support" OFF)
tvm_option(USE_MICRO_STANDALONE_RUNTIME "Build with micro.standalone_runtime support" OFF)
tvm_option(USE_ANTLR "Build with ANTLR for Relay parsing" OFF)
tvm_option(USE_CPP_RPC "Build CPP RPC" OFF)
tvm_option(USE_TFLITE "Build with tflite support" OFF)
tvm_option(USE_TENSORFLOW_PATH "TensorFlow root path when use TFLite" none)

if(USE_CPP_RPC AND UNIX)
message(FATAL_ERROR "USE_CPP_RPC is only supported with WIN32. Use the Makefile for non-Windows.")
endif()

# include directories
include_directories(${CMAKE_INCLUDE_PATH})
include_directories("include")
Expand Down Expand Up @@ -309,6 +314,9 @@ add_library(tvm SHARED ${COMPILER_SRCS} ${RUNTIME_SRCS})
add_library(tvm_topi SHARED ${TOPI_SRCS})
add_library(tvm_runtime SHARED ${RUNTIME_SRCS})

if(USE_CPP_RPC)
add_subdirectory("apps/cpp_rpc")
endif()

if(USE_RELAY_DEBUG)
message(STATUS "Building Relay in debug mode...")
Expand Down
27 changes: 27 additions & 0 deletions apps/cpp_rpc/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
set(TVM_RPC_SOURCES
main.cc
rpc_env.cc
rpc_server.cc
)

if(WIN32)
list(APPEND TVM_RPC_SOURCES win32_process.cc)
endif()

# Set output to same directory as the other TVM libs
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR})
add_executable(tvm_rpc ${TVM_RPC_SOURCES})
set_property(TARGET tvm_rpc PROPERTY INTERPROCEDURAL_OPTIMIZATION_RELEASE TRUE)

if(WIN32)
target_compile_definitions(tvm_rpc PUBLIC -DNOMINMAX)
endif()

target_include_directories(
tvm_rpc
PUBLIC "../../include"
PUBLIC DLPACK_PATH
PUBLIC DMLC_PATH
)

target_link_libraries(tvm_rpc tvm_runtime)
10 changes: 8 additions & 2 deletions apps/cpp_rpc/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
# TVM RPC Server
This folder contains a simple recipe to make RPC server in c++.

## Usage
## Usage (Non-Windows)
- Build tvm runtime
- Make the rpc executable [Makefile](Makefile).
`make CXX=/path/to/cross compiler g++/ TVM_RUNTIME_DIR=/path/to/tvm runtime library directory/ OS=Linux`
Expand All @@ -35,6 +35,12 @@ This folder contains a simple recipe to make RPC server in c++.
```
- Use `./tvm_rpc server` to start the RPC server

## Usage (Windows)
- Build tvm with the argument -DUSE_CPP_RPC
- Install [LLVM pre-build binaries](https://releases.llvm.org/download.html), making sure to select the option to add it to the PATH.
- Verify Python 3.6 or newer is installed and in the PATH.
- Use `<tmv_output_dir>\tvm_rpc.exe` to start the RPC server

## How it works
- The tvm runtime dll is linked along with this executable and when the RPC server starts it will load the tvm runtime library.

Expand All @@ -53,4 +59,4 @@ Command line usage
```

## Note
Currently support is only there for Linux / Android environment and proxy mode doesn't be supported currently.
Currently support is only there for Linux / Android / Windows environment and proxy mode doesn't be supported currently.
95 changes: 68 additions & 27 deletions apps/cpp_rpc/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,12 @@
* \file rpc_server.cc
* \brief RPC Server for TVM.
*/
#include <stdlib.h>
#include <signal.h>
#include <stdio.h>
#include <cstdlib>
#include <csignal>
#include <cstdio>
#if defined(__linux__) || defined(__ANDROID__)
#include <unistd.h>
#endif
#include <dmlc/logging.h>
#include <iostream>
#include <cstring>
Expand All @@ -35,11 +37,15 @@
#include "../../src/support/socket.h"
#include "rpc_server.h"

#if defined(_WIN32)
#include "win32_process.h"
#endif

using namespace std;
using namespace tvm::runtime;
using namespace tvm::support;

static const string kUSAGE = \
static const string kUsage = \
"Command line usage\n" \
" server - Start the server\n" \
"--host - The hostname of the server, Default=0.0.0.0\n" \
Expand Down Expand Up @@ -73,13 +79,16 @@ struct RpcServerArgs {
string key;
string custom_addr;
bool silent = false;
#if defined(WIN32)
std::string mmap_path;
#endif
};

/*!
* \brief PrintArgs print the contents of RpcServerArgs
* \param args RpcServerArgs structure
*/
void PrintArgs(struct RpcServerArgs args) {
void PrintArgs(const RpcServerArgs& args) {
LOG(INFO) << "host = " << args.host;
LOG(INFO) << "port = " << args.port;
LOG(INFO) << "port_end = " << args.port_end;
Expand All @@ -89,6 +98,7 @@ void PrintArgs(struct RpcServerArgs args) {
LOG(INFO) << "silent = " << ((args.silent) ? ("True"): ("False"));
}

#if defined(__linux__) || defined(__ANDROID__)
/*!
* \brief CtrlCHandler, exits if Ctrl+C is pressed
* \param s signal
Expand All @@ -109,7 +119,7 @@ void HandleCtrlC() {
sigIntHandler.sa_flags = 0;
sigaction(SIGINT, &sigIntHandler, nullptr);
}

#endif
/*!
* \brief GetCmdOption Parse and find the command option.
* \param argc arg counter
Expand All @@ -129,7 +139,7 @@ string GetCmdOption(int argc, char* argv[], string option, bool key = false) {
}
// We assume "=" is the end of option.
CHECK_EQ(*option.rbegin(), '=');
cmd = arg.substr(arg.find("=") + 1);
cmd = arg.substr(arg.find('=') + 1);
return cmd;
}
}
Expand All @@ -156,41 +166,41 @@ bool ValidateTracker(string &tracker) {
* \brief ParseCmdArgs parses the command line arguments.
* \param argc arg counter
* \param argv arg values
* \param args, the output structure which holds the parsed values
* \param args the output structure which holds the parsed values
*/
void ParseCmdArgs(int argc, char * argv[], struct RpcServerArgs &args) {
string silent = GetCmdOption(argc, argv, "--silent", true);
const string silent = GetCmdOption(argc, argv, "--silent", true);
if (!silent.empty()) {
args.silent = true;
// Only errors and fatal is logged
dmlc::InitLogging("--minloglevel=2");
}

string host = GetCmdOption(argc, argv, "--host=");
const string host = GetCmdOption(argc, argv, "--host=");
if (!host.empty()) {
if (!ValidateIP(host)) {
LOG(WARNING) << "Wrong host address format.";
LOG(INFO) << kUSAGE;
LOG(INFO) << kUsage;
exit(1);
}
args.host = host;
}

string port = GetCmdOption(argc, argv, "--port=");
const string port = GetCmdOption(argc, argv, "--port=");
if (!port.empty()) {
if (!IsNumber(port) || stoi(port) > 65535) {
LOG(WARNING) << "Wrong port number.";
LOG(INFO) << kUSAGE;
LOG(INFO) << kUsage;
exit(1);
}
args.port = stoi(port);
}

string port_end = GetCmdOption(argc, argv, "--port_end=");
const string port_end = GetCmdOption(argc, argv, "--port_end=");
if (!port_end.empty()) {
if (!IsNumber(port_end) || stoi(port_end) > 65535) {
LOG(WARNING) << "Wrong port_end number.";
LOG(INFO) << kUSAGE;
LOG(INFO) << kUsage;
exit(1);
}
args.port_end = stoi(port_end);
Expand All @@ -200,26 +210,34 @@ void ParseCmdArgs(int argc, char * argv[], struct RpcServerArgs &args) {
if (!tracker.empty()) {
if (!ValidateTracker(tracker)) {
LOG(WARNING) << "Wrong tracker address format.";
LOG(INFO) << kUSAGE;
LOG(INFO) << kUsage;
exit(1);
}
args.tracker = tracker;
}

string key = GetCmdOption(argc, argv, "--key=");
const string key = GetCmdOption(argc, argv, "--key=");
if (!key.empty()) {
args.key = key;
}

string custom_addr = GetCmdOption(argc, argv, "--custom_addr=");
const string custom_addr = GetCmdOption(argc, argv, "--custom_addr=");
if (!custom_addr.empty()) {
if (!ValidateIP(custom_addr)) {
LOG(WARNING) << "Wrong custom address format.";
LOG(INFO) << kUSAGE;
LOG(INFO) << kUsage;
exit(1);
}
args.custom_addr = custom_addr;
}
#if defined(WIN32)
const string mmap_path = GetCmdOption(argc, argv, "--child_proc=");
if(!mmap_path.empty()) {
args.mmap_path = mmap_path;
dmlc::InitLogging("--minloglevel=0");
}
#endif

}

/*!
Expand All @@ -229,17 +247,34 @@ void ParseCmdArgs(int argc, char * argv[], struct RpcServerArgs &args) {
* \return result of operation.
*/
int RpcServer(int argc, char * argv[]) {
struct RpcServerArgs args;
RpcServerArgs args;

/* parse the command line args */
ParseCmdArgs(argc, argv, args);
PrintArgs(args);

// Ctrl+C handler
LOG(INFO) << "Starting CPP Server, Press Ctrl+C to stop.";
#if defined(__linux__) || defined(__ANDROID__)
// Ctrl+C handler
HandleCtrlC();
tvm::runtime::RPCServerCreate(args.host, args.port, args.port_end, args.tracker,
args.key, args.custom_addr, args.silent);
#endif

#if defined(WIN32)
if(!args.mmap_path.empty()) {
int ret = 0;

try {
ChildProcSocketHandler(args.mmap_path);
} catch (const std::exception&) {
ret = -1;
}

return ret;
}
#endif

RPCServerCreate(args.host, args.port, args.port_end, args.tracker,
args.key, args.custom_addr, args.silent);
return 0;
}

Expand All @@ -251,15 +286,21 @@ int RpcServer(int argc, char * argv[]) {
*/
int main(int argc, char * argv[]) {
if (argc <= 1) {
LOG(INFO) << kUSAGE;
LOG(INFO) << kUsage;
return 0;
}

// Runs WSAStartup on Win32, no-op on POSIX
Socket::Startup();
#if defined(_WIN32)
SetEnvironmentVariableA("CUDA_CACHE_DISABLE", "1");
#endif

if (0 == strcmp(argv[1], "server")) {
RpcServer(argc, argv);
} else {
LOG(INFO) << kUSAGE;
return RpcServer(argc, argv);
}

LOG(INFO) << kUsage;

return 0;
}
Loading